diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. 
- input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 8b2cae7a022808b570958c7f49a9261a62833579..d88b73aee97585abd2faa5a6800b1b3922dbe1c6 100755 --- a/build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d8b08406547ecaf4b08409b5c8a5144ac0f91faac6c28dcfa6938dd75470db34 -size 70296128 +oid sha256:35967133ffbd0cac32aafc9e70f441264b1f41710f4f86d68723d2eb9a59cfe8 +size 85717704 diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. 
+ + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.:
+      `stored_value = value + bias`,
+    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this is an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number of bits representing an integer excluding the sign bit if
+    this is an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0) for example if we store the
+    type as an unsigned integer with a bias of 128 then the value 0 will be
+    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: if infs are supported, use `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type, as a NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
+        # e is the exponent bits), there is some precedent for non-standard
+        # biases, example `float8_e4m3b11fnuz` here:
+        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+        # complication we are just assuming the standard exponent bias until
+        # there is a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        # Compare enum members directly; comparing against NanRepr.NONE.value
+        # (an int) would never match and always report NaN support.
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+          `float_em[flags]`
+          flags:
+            - no-flags: means it follows IEEE 754 conventions
+            - f: means finite values only (no infinities)
+            - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+          `[u]int[b]`
+            - if bias is not present it means its zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch24-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name]
+FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"],
+}
+
+
+def pack_quantized_values_into_int32(
+    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
+):
+    # move dim to pack to the end
+    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
+    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+    w_q_perm = w_q.permute(perm)
+
+    pack_factor = 32 // wtype.size_bits
+    mask = (1 << wtype.size_bits) - 1
+
+    new_shape_perm = list(w_q_perm.shape)
+    assert w_q_perm.shape[-1] % pack_factor == 0
+    new_shape_perm[-1] //= pack_factor
+
+    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
+    for i in range(pack_factor):
+        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
+
+    return res.permute(inv_perm)
+
+
+def unpack_quantized_values_into_int32(
+    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
+):
+    # move dim to pack to the end
+    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
+    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+    w_q_perm = w_q.permute(perm)
+
+    pack_factor = 32 // wtype.size_bits
+    mask = (1 << wtype.size_bits) - 1
+
+    new_shape_perm = list(w_q_perm.shape)
+    new_shape_perm[-1] *= pack_factor
+
+    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
+    for i in range(pack_factor):
+        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
+
+    return res.permute(inv_perm)
+
+
+def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
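+    # In other words, the reference computed below is either
+    # w_q * w_s - maybe_w_zp * w_s (zero-point applied after scaling, as
+    # Machete does) or the conventional (w_q - maybe_w_zp) * w_s; the two
+    # forms are algebraically equal but round differently in low precision.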
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
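+# Concretely: for per-group quantization (group_size < size_k) the 4-bit
+# weights get per-group half-precision scales that are then fused with (i.e.
+# divided by) a per-channel int8 scale, while for per-channel quantization
+# (group_size == size_k) only the per-channel scale is produced and s_group
+# is returned empty.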
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 70d8649213103e490653df66a4815891e5955055..f0ce7f35a52c01d112009a678af21b6d5f33b6ea 100755 --- a/build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d1fec120a5de02ea58eb455e5f6483ce15da4672c356982bd4ac070864755e28 -size 86065792 +oid sha256:c0abd1636906a69e8cf7d85fdfc7b99b6b4f4cc3d753431ad3d49ba674238c27 +size 105267936 diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
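+    # The kernel fills in one float32 scale per token row (and, if
+    # asymmetric, one int32 zero-point per row) in the buffers allocated
+    # below.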
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
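+# For example, the GPTQ-style uint4b8 type defined at the bottom of this file
+# is ScalarType.uint(4, 8): 4 stored bits with a bias of 8, so the stored
+# value 0 decodes to -8 and the stored value 15 decodes to 7.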
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
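+        # (Recall the IEEE-754 double layout: 1 sign bit, 11 exponent bits in
+        # bits 62..52, and 52 mantissa bits in bits 51..0; the sign bit stays
+        # 0 here since this is the positive maximum.)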
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
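+# Note on the interleave step in marlin_zero_points above (4-bit case): each
+# run of eight zero-point columns [z0, z1, ..., z7] is reordered to
+# [z0, z2, z4, z6, z1, z3, z5, z7] before pack_cols packs the eight 4-bit
+# values into a single int32, matching the order in which the Marlin
+# dequantize code reads them.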
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
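+# For example, for a torch.half dense input of shape (m, k) the metadata dtype
+# is torch.int16 (four quadruples encoded per metadata element), so this
+# function returns a compressed tensor of shape (m, k // 2) together with a
+# reordered metadata tensor of shape (m, k // 16).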
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
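+    # Worked example of the bit logic below: for the quadruple
+    # [True, False, True, False] (m0=1, m1=0, m2=1, m3=0) it yields
+    # idxs0 = 0 and idxs1 = 2, i.e. the kept values sit at positions 0 and 2,
+    # and meta_4 = idxs0 | (idxs1 << 2) = 0b1000, matching the table above.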
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch24-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
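+    # The two orderings, w_q * w_s - maybe_w_zp * w_s and
+    # (w_q - maybe_w_zp) * w_s, are algebraically identical but round
+    # differently in floating point, so the reference is computed to match
+    # whichever order the kernel under test uses.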
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
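+# Concretely, in the per-group path (group_size < size_k) the 4-bit weights get
+# per-group fp16 scales expressed relative to a per-channel scale (s_group is
+# divided by s_channel), while in the per-channel path (group_size == size_k)
+# only s_channel is produced and s_group is returned empty.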
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 198d805b6aeeabe02443f6e2d517e36bc5300b5d..bede91a0ddfcd855c5ab978ef853a751715077dd 100755 --- a/build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1807b0b92ecf8fbb62bdf33fc4ba57dbabc177082c3e0e24b1ac7cd085462ae0 -size 89584848 +oid sha256:a458d5efc51f80028811707ee7b9fadb00f3bfc49917c8377188c286c4bd8e12 +size 109249352 diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
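# For intuition, an illustrative instance (not used by the code below): the
# "uint4b8" type used for GPTQ-style 4-bit weights is an unsigned 4-bit
# integer with a bias of 8 (exponent=0, mantissa=4, signed=False, bias=8), so
# stored values 0..15 decode to -8..7 via value = stored_value - bias.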
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
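        # Worked example (illustration only): for standard float16
        # (exponent=5, mantissa=10, IEEE 754 NaNs) this computes
        #   max_mantissa        = (1 << 10) - 1 = 1023
        #   max_exponent        = (1 << 5) - 2  = 30
        #   exponent_bias       = (1 << 4) - 1  = 15
        #   max_exponent_double = 30 - 15 + 1023 = 1038
        # and the assembled double bit pattern decodes to 65504.0, the usual
        # fp16 maximum.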
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
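# For example (illustrative, mirroring the branches of the function below): on
# a GPU with device_capability >= 80,
#   query_marlin_supported_quant_types(has_zp=True)  -> [uint4, uint8]        (AWQ-style)
#   query_marlin_supported_quant_types(has_zp=False) -> [uint4b8, uint8b128]  (GPTQ-style)
# and on older devices the returned list is empty.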
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
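# Shape sanity check for marlin_zero_points above (an assumption for
# illustration, based on pack_cols packing 32 // num_bits column entries into
# each int32): a zero-point tensor of shape (num_groups, size_n) packs to
# (num_groups, size_n // 8) for num_bits=4 and (num_groups, size_n // 4) for
# num_bits=8.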
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
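# Shape example (illustration only): for a torch.half dense matrix of shape
# (m, k) = (128, 256), ksparse is 4 and each int16 metadata element holds four
# quad-encodings, so the function below returns a (128, 128) "compressed"
# tensor (k // 2 columns) plus a (128, 16) metadata tensor
# (k // (4 * quadbits_per_meta_elem) columns).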
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
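    # Decode example for a single quadruple (illustration only): the pattern
    # [True, False, True, False] is encoded as 0b1000; the low two bits (0b00)
    # are the index of the first kept value (0) and the high two bits (0b10)
    # are the index of the second kept value (2), which is what the bit
    # expressions below (idxs0 and idxs1) compute.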
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch24-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
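+# is_layer_skipped() below expands a fused projection prefix (for example,
+# "model.layers.0.self_attn.qkv_proj") into its per-shard prefixes via this
+# mapping before consulting the ignore list.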
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
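+    # Concretely: w_ref = (w_q * w_s) - (w_zp * w_s) instead of
+    # w_ref = (w_q - w_zp) * w_s. The two are algebraically identical; only
+    # the floating point rounding differs, and the first form matches the
+    # order of operations used by those kernels.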
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
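+# Here "per-channel" means group_size == size_k (the group_size == -1 case
+# after normalization): only per-channel int8 scales are produced and s_group
+# is empty. Otherwise the per-group int4 scales are fused with (divided by)
+# the per-channel int8 scales.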
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 7b243a20687f55bd08362a0269da9a4b651f6c00..c2edd22ce154ba6ac696fa5a9d6868c13ba7e408 100755 --- a/build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfb228c577303d6d981b88ab6955447786f88e4a551b42b3ab13225fb96aa81b -size 70283536 +oid sha256:734f235fc2749269910ee4e988da205a9442edf73c0f9b3ef41fff100bc66707 +size 85709024 diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
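+    # One scale (and, if asymmetric, one azp) is produced per token, i.e. per
+    # row of the input viewed as (num_tokens, hidden_size).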
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
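+# For example, the standard GPTQ 4-bit type defined below is
+# ScalarType.uint(4, 8): 4 stored bits with a bias of 8, so raw values 0..15
+# represent -8..7.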
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
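+        # (IEEE 754 double layout: 1 sign bit, 11 exponent bits, 52 mantissa
+        # bits, so the exponent field starts at bit 52 and the mantissa fills
+        # bits 0..51.)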
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
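+# Illustrative shape sketch for marlin_zero_points. The calling convention
+# (the first argument holds one row of zero-points per quantization group and
+# size_k equals that group count) is inferred from the pack_cols contract in
+# quant_utils, not from an additional documented API:
+#
+#   num_groups, size_n, num_bits = 32, 4096, 4
+#   zp = torch.randint(0, 2**num_bits, (num_groups, size_n), dtype=torch.int32)
+#   packed = marlin_zero_points(zp, size_k=num_groups, size_n=size_n,
+#                               num_bits=num_bits)
+#   # 4-bit values are packed 8-per-int32 along columns:
+#   # packed.shape == (num_groups, size_n // (32 // num_bits)) == (32, 512)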
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
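+# For an (m, k) dense input the function returns a pair: a "compressed"
+# (m, k // 2) tensor holding the kept values, and a reordered metadata tensor
+# with one element per ksparse * quadbits_per_meta_elem dense columns per row.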
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
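+    # The Boolean expressions below implement that table directly: idxs0 holds
+    # the two-bit index of the first kept element and idxs1 the two-bit index
+    # of the second kept element, both derived purely from the non-zero flags
+    # m0..m3 of each quadruple.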
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch24-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
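+    # The two branches are algebraically identical ((w_q - zp) * s equals
+    # w_q * s - zp * s); they differ only in floating-point rounding order,
+    # so the reference is computed in the same order as the kernel under test.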
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
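+# In the per-group case (group_size < size_k) the returned s_group holds the
+# group scales already divided by the per-channel int8 scales (as fp16); in
+# the per-channel case s_group is returned as an empty tensor.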
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/__init__.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch24-cxx98-cu121-x86_64-linux/quantization/__init__.py +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/_ops.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch24-cxx98-cu121-x86_64-linux/quantization/_ops.py +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch24-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index c0d9eb05213e8d219e9dcbe119de8b50e5b9c490..e12c8d6193ae8bd98d37a51bebaafacd7b81c95c 100755 --- a/build/torch24-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3286c3c772cef25a780bd7510251bbed8cf2dcc460df9bba64f7a8dbefc71a1a -size 86060232 +oid sha256:516d2dbd3669ce8b9fd78f84413747bce207223f9987cbdb68e042c8ab3688ac +size 105258480 diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
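+    # One scale (and, for asymmetric quantization, one azp) is produced per
+    # token, i.e. per row of the input flattened over its leading dimensions.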
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/cutlass.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/marlin.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/scalar_type.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
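Annotation (editor's sketch, not part of the generated file): the bias convention implemented by the ScalarType class defined next can be sanity-checked with plain integer arithmetic. For the GPTQ-style `uint4b8` type (4 unsigned bits, bias 8), `value = stored_value - bias`:

# Sketch only: raw 4-bit codes 0..15 with a bias of 8 decode to -8..7,
# matching the `uint4b8` entry defined further down in `scalar_types`.
bias = 8
stored_codes = range(16)                          # raw 4-bit codes
values = [code - bias for code in stored_codes]   # value = stored_value - bias
assert min(values) == -8 and max(values) == 7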
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.
+    `stored_value = value + bias`; this is useful for quantized types
+    (e.g. standard GPTQ 4-bit uses a bias of 8). The implementation for this
+    class can be found in csrc/core/scalar_type.hpp; these type signatures
+    should be kept in sync with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this is an integer type).
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number of bits representing an integer (excluding the sign bit)
+    if this is an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0). For example, if we store the
+    type as an unsigned integer with a bias of 128, then the value 0 will be
+    stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: if infs are supported, use `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type; a NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) - 1
+        # (where e is the number of exponent bits); there is some precedent
+        # for non-standard biases, e.g. `float8_e4m3b11fnuz` in
+        # https://github.com/jax-ml/ml_dtypes, but to avoid premature
+        # complication we assume the standard exponent bias until there is a
+        # need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
+        has a sign bit), same as `signed`;
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        "If the type supports NaN values"
+        # Compare against the enum member (not its .value): nan_repr holds a
+        # NanRepr enum, so comparing to an int would always be unequal.
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        # Compare against the enum member (not its .value) for the same
+        # reason as in has_nans().
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int<size_bits>[b<bias>]`
+          - if the bias is not present it means it is zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/__init__.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
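Annotation (editor's sketch, not part of the file): the workspace sizing used by `marlin_make_workspace` above, shown with a hypothetical shard width:

GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MAX_PARALLEL = 16

output_size_per_partition = 4096  # hypothetical N for one tensor-parallel shard
max_workspace_size = (
    output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N
) * GPTQ_MARLIN_MAX_PARALLEL
assert max_workspace_size == 1024  # number of int32 slots zero-initialized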
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
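Annotation (editor's sketch, not part of the file): the 2:4 pattern that the conversion below compresses, checked on a tiny hypothetical tensor; every contiguous group of four values along a row has at most two non-zeros:

import torch

dense = torch.tensor([[1.0, 0.0, 2.0, 0.0,
                       0.0, 3.0, 0.0, 4.0]])
groups = dense.view(-1, 4)                        # groups of 4 along the row
assert int((groups != 0).sum(dim=1).max()) <= 2   # 2:4 sparsity holds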
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch24-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
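Annotation (editor's sketch, not part of the file): how `is_layer_skipped` below expands a fused module name into its per-shard names before consulting the ignore list (the layer name here is hypothetical):

# Mirrors the expansion done in is_layer_skipped(); the mapping is copied
# from FUSED_LAYER_NAME_MAPPING defined below.
fused_map = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
prefix = "model.layers.0.self_attn.qkv_proj"   # hypothetical layer prefix

proj_name = prefix.split(".")[-1]
shard_prefixes = [prefix.replace(proj_name, shard) for shard in fused_map[proj_name]]
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.self_attn.k_proj',
#  'model.layers.0.self_attn.v_proj']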
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
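A rough usage sketch of the QQQ helper defined below; the module path and tensor shapes here are illustrative assumptions, not part of the diff:

import torch
# assumed import path for the quant_utils module added in this diff
from quantization.utils.quant_utils import qqq_quantize_weights

w = torch.randn(512, 256)  # [size_k, size_n], any float dtype
w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4, group_size=128)
# q_w holds unsigned 4-bit codes, s_group the fused per-group scales (fp16),
# and s_channel the per-channel int8 scales (fp32).
assert q_w.shape == w.shape and s_channel.shape == (1, 256)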
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/__init__.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch24-cxx98-cu124-x86_64-linux/quantization/__init__.py +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/_ops.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch24-cxx98-cu124-x86_64-linux/quantization/_ops.py +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch24-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index a49d853ee227d0304f5a3073366524ef42c92b3d..af7bc1313355213be6278bf73ed3bc46031f1d31 100755 --- a/build/torch24-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:57c4f62051318e944e4f835f214e1504e960df161460b153f55be2909bee60d8 -size 89571096 +oid sha256:b0840d3a079bbf070d21282020386b2fc121da9894be0fa88ffcb6680d92bf0f +size 109243600 diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/cutlass.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/marlin.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
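For orientation, a small sketch of how the pure-Python ScalarType defined next encodes the standard GPTQ 4-bit type; the import mirrors the one used in marlin_utils.py, and the asserted values follow from the definitions below:

from quantization.scalar_type import ScalarType

uint4b8 = ScalarType.uint(4, 8)  # unsigned 4-bit with a bias of 8
assert uint4b8.size_bits == 4
assert (uint4b8.min(), uint4b8.max()) == (-8, 7)
assert str(uint4b8) == "uint4b8"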
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.
+      `stored_value = value + bias`,
+    which is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this is an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number of bits representing an integer (excluding the sign bit)
+    if this is an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    Bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0). For example, if we store the
+    type as an unsigned integer with a bias of 128, then the value 0 will be
+    stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: to check whether infs are supported, use `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type; returns a NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) - 1
+        # (where e is the number of exponent bits); there is some precedent
+        # for non-standard biases, e.g. `float8_e4m3b11fnuz` in
+        # https://github.com/jax-ml/ml_dtypes, but to avoid premature
+        # complication we assume the standard exponent bias until there is
+        # a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int<size_bits>[b<bias>]`
+          - if bias is not present it means it's zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/__init__.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
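A minimal usage sketch for marlin_zero_points above; the shapes are illustrative and the import path is an assumption based on this build layout:

import torch
from quantization.utils.marlin_utils import marlin_zero_points  # path assumed

num_groups, size_n, num_bits = 8, 256, 4
zp = torch.randint(0, 2**num_bits, (num_groups, size_n), dtype=torch.int32)
zp_marlin = marlin_zero_points(zp, size_k=num_groups, size_n=size_n, num_bits=num_bits)
# 4-bit values are packed 8-per-int32 along the column dimension.
assert zp_marlin.shape == (num_groups, size_n // 8)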
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
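A minimal round-trip sketch for the two CUTLASS 2:4 helpers defined below (hypothetical test code, not part of the build output; it assumes this module, including the `mask_creator` helper further down in the file, is importable):

import torch

def _example_24_roundtrip():  # hypothetical name, for illustration only
    # 64 rows (divisible by 32, as required for fp16 metadata) and 128
    # columns (divisible by 4 * quadbits_per_meta_elem = 16).
    dense = torch.randn(64, 128, dtype=torch.half)
    # Enforce 2:4 structure first; compressing an arbitrary matrix is lossy.
    dense = dense * mask_creator(dense).half()
    sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
    assert sparse.shape == (64, 64)    # k is halved
    assert meta.dtype == torch.int16   # fp16 inputs use int16 metadata
    recovered = sparse_semi_structured_to_dense_cutlass(sparse, meta)
    assert torch.equal(recovered, dense)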
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
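    # Worked example of the table above (illustrative): for the quadruple
    # [True, False, False, True] the encoding is 0b1100, i.e. the low two bits
    # (0b00) are the index of the first kept element and the high two bits
    # (0b11) are the index of the second kept element, so positions 0 and 3
    # survive compression.  This matches the packing below, where
    # meta_4 = idxs0 | (idxs1 << 2).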
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch24-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
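    # Concretely (illustrative only): the standard reference is
    #     w_ref = (w_q - zp) * s
    # while the zero-points-after-scales variant used for Machete computes
    #     w_ref = w_q * s - zp * s
    # The two are mathematically identical and differ only in floating-point
    # rounding order.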
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
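A short usage sketch for the GPTQ path above (hypothetical test code; `_example_gptq_quantize` is not part of this module and assumes `quantization.scalar_type` is importable):

import torch

from quantization.scalar_type import scalar_types

def _example_gptq_quantize():
    w = torch.randn(512, 256, dtype=torch.half)
    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
        w, scalar_types.uint4b8, group_size=128, act_order=False
    )
    assert q_w.dtype == torch.int32      # packed into int32 later by gptq_pack
    assert s.shape == (512 // 128, 256)  # one scale row per group of 128
    assert g_idx.numel() == 0            # only populated when act_order=True
    # w_ref is the dequantized reference; useful for checking quantization error.
    print((w.float() - w_ref.float()).abs().max())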
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 1fc98b35489370a4212fe4536a9acf316f1f8efa..6a8a56ad608290ae9c942ccc440f889cb517d81b 100755 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4d77e7f7252bda65b70bff988a6e1883ce523d6a2c4d33674d000165204903b7 -size 70296128 +oid sha256:6f21a34510be03ab2a0ef92fed8db8aae8170d257922fbe6b2917a2b21b8df07 +size 85717704 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
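    # One float32 scale (and, for asymmetric quantization, one int32
    # zero-point) is produced per token, i.e. per row of the flattened input.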
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/marlin.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
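An illustrative check of the bias convention (hypothetical snippet; it relies only on the `ScalarType` class defined just below):

def _example_uint4b8():
    # Standard GPTQ 4-bit: unsigned 4-bit storage with a bias of 8,
    # so the representable values span [-8, 7].
    uint4b8 = ScalarType.uint(4, 8)
    assert uint4b8.size_bits == 4
    assert (uint4b8.min(), uint4b8.max()) == (-8, 7)
    assert uint4b8.has_bias() and not uint4b8.is_floating_point()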
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
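        # Worked example (illustrative): for an IEEE e5m2 type,
        # max_mantissa = 0b11 and max_exponent = 30 with exponent_bias = 15,
        # so max_exponent_double = 30 - 15 + 1023 = 1038 and the assembled
        # double equals 1.75 * 2**15 = 57344.0, the float8_e5m2 maximum.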
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
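To make the workspace sizing above concrete, a small usage sketch (it assumes a CUDA device is available; the numbers follow directly from the constants defined in this file):

import torch
from quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL,
    GPTQ_MARLIN_MIN_THREAD_N,
    marlin_make_workspace,
)

# A 4096-wide output shard needs (4096 // 64) * 16 = 1024 int32 slots.
ws = marlin_make_workspace(4096, torch.device("cuda"))
assert ws.numel() == (4096 // GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
assert ws.dtype == torch.int32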
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
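For example — a sketch assuming CPU tensors and the helpers defined below — round-tripping a half-precision matrix with a true 2:4 pattern through the compression and decompression reconstructs it exactly:

import torch
from quantization.utils.marlin_utils_test_24 import (
    mask_creator,
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

w = torch.randn(32, 64)                  # fp16 path needs m % 32 == 0, k % 16 == 0
dense_24 = (w * mask_creator(w)).half()  # keep 2 of every 4 weights by magnitude
sparse, meta = sparse_semi_structured_from_dense_cutlass(dense_24)
assert sparse.shape == (32, 32)          # the k dimension is halved
assert meta.dtype == torch.int16         # four 4-bit quads packed per element
recon = sparse_semi_structured_to_dense_cutlass(sparse, meta)
assert torch.equal(recon, dense_24)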
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
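+    # Worked example: for the quadruple [True, False, False, True] the table
+    # above specifies 0b1100. The expressions below reproduce it:
+    #   expr0 = expr1 = expr2 = 0 and m3 = 1, so
+    #   idxs0 = 0 (lowest-index True) and idxs1 = bit2 | (bit3 << 1) = 3,
+    #   giving a meta quad of idxs0 | (idxs1 << 2) = 0b1100.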
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
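+    # Both orderings describe the same dequantization algebraically,
+    #   (w_q - zp) * s == w_q * s - zp * s,
+    # but they round differently in low precision, hence the two reference paths.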
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
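As a usage sketch of the symmetric, biased path implemented by quantize_weights above (it assumes the `quantization` package import used throughout these utilities; the QQQ variant below uses its own scheme):

import torch
from quantization.scalar_type import scalar_types
from quantization.utils.quant_utils import quantize_weights

w = torch.randn(128, 256)
w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8, group_size=64)
assert int(w_q.min()) >= 0 and int(w_q.max()) <= 15  # stored range after the +8 bias
assert w_s.shape == (128 // 64, 256)                 # one scale per 64-row group, per column
assert w_zp is None                                  # symmetric: no runtime zero-point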
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 271eec7ab9405711e7f195c31ddd041d2e37b895..f75c04707fddddd9b78d9be7ef7fdf857d6f550f 100755 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:015c401d9f0498c0b79651e43788ce8daa1f344eda1dd765a9c863aa8071c3c0 -size 86065792 +oid sha256:7451eaf399d27e0f7fbac108964d862b97b8f12a5fb6decdf9a955874aa95548 +size 105267936 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
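+    # One fp32 scale (and, for asymmetric quantization, one int32 azp) is
+    # produced per token, i.e. per row of the flattened input, by the
+    # dynamic kernel below.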
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
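+    :param bias: Optional per-output-channel bias; must have n elements
+        and dtype out_dtype.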
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/marlin.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
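+# For example, the GPTQ-style 4-bit type defined below as
+# scalar_types.uint4b8 is ScalarType.uint(4, 8): 4-bit unsigned storage with
+# a bias of 8, so stored values 0..15 represent the logical range -8..7.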
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float<size_bits>_e<exponent>m<mantissa>[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int<size_bits>[b<bias>]`
+          - if bias is not present it means it is zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
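+
+# A minimal round-trip sketch for the packing helpers imported above (purely
+# illustrative, not used by the kernels; it assumes num_bits == 4, i.e. eight
+# 4-bit values packed into each int32 along the column dimension):
+#
+#   q = torch.randint(0, 16, (128, 64), dtype=torch.int32)
+#   packed = pack_cols(q, 4, 128, 64)                  # shape (128, 8)
+#   assert torch.equal(unpack_cols(packed, 4, 128, 64), q)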
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
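+# For an (m, k) input with a non-float dtype, the returned compressed tensor
+# has shape (m, k // 2) and the reordered metadata tensor has shape
+# (m, k // (4 * quadbits_per_meta_elem)).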
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
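+    # For example, the quadruple [True, False, False, True] keeps positions 0
+    # and 3, so the logic below yields idxs0 == 0 and idxs1 == 3, and the meta
+    # nibble idxs0 | (idxs1 << 2) equals 0b1100, matching the table above.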
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index ceaea9800c5ddcebda2a3cfc81fc3007ea8eef10..36067123090f6135455fb5b820a0d5acde927fca 100755 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f6f77dea7aebf51e168fa341a0f6a47628c5b15fe8d33805f32203c636ada4e5 -size 89584848 +oid sha256:4f539485ab9a338fe8d1ed5de27bc9b0e6295c2c469910f0948dfa69ef629baf +size 109249352 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/marlin.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
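+
+# A minimal usage sketch (illustrative only; the sizes below are assumptions):
+# `marlin_zero_points` takes the *unpacked* integer zero-points laid out as
+# (num_groups, size_n) and returns them Marlin-permuted and packed into int32,
+# i.e. shape (num_groups, size_n // (32 // num_bits)). For 4-bit weights with
+# 8 groups over 256 output channels:
+#
+#     zp = torch.randint(0, 16, (8, 256))           # hypothetical zero-points
+#     zp_marlin = marlin_zero_points(zp, 8, 256, 4)  # -> int32, shape (8, 32)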
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
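+#
+# Shape sketch (illustrative, assuming an fp16 input): for a dense matrix of
+# shape (m, k) the returned "compressed" tensor has shape (m, k // 2) and the
+# reordered int16 metadata has shape (m, k // 16), e.g.
+#
+#     dense = torch.randn(64, 128, dtype=torch.half)
+#     sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
+#     # sparse.shape == (64, 64), meta.shape == (64, 8), meta.dtype == torch.int16
+#
+# Only a genuinely 2:4-sparse input (e.g. produced via mask_creator / inject_24
+# below) round-trips losslessly through sparse_semi_structured_to_dense_cutlass.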
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
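+#
+# Example of how the mapping is consumed (see is_layer_skipped below): a fused
+# prefix such as "model.layers.0.self_attn.qkv_proj" is expanded into the
+# "q_proj", "k_proj" and "v_proj" shard prefixes before checking the ignore
+# list, and all shards must agree on whether they are quantized.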
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
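+    # Concretely, the two branches below compute
+    #     w_ref = w_q * s - zp * s        (zero-point applied after scaling)
+    #     w_ref = (w_q - zp) * s          (zero-point applied before scaling)
+    # which are algebraically identical but can round differently in low
+    # precision, hence the kernel-matching reference.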
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
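+#
+# Shape sketch (illustrative; the sizes are assumptions): for 4-bit QQQ with
+# per-group quantization,
+#
+#     w = torch.randn(1024, 256, dtype=torch.half)
+#     w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, 4, group_size=128)
+#     # q_w:       int32, shape (1024, 256), values in [0, 15]
+#     # s_group:   fp16 per-group scales, shape (1024 // 128, 256) == (8, 256)
+#     # s_channel: fp32 per-channel scales, shape (1, 256)
+#
+# whereas group_size == -1 (or group_size == size_k) selects the per-channel
+# scheme, in which s_group is returned empty and only s_channel is populated.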
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index eafddc97647c3e611108b8b7858999e994a29919..70409925396de0cbefabcb89dd68d1fd8862fadc 100755 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ded84bcaf0bf7fdb8dc36d86e78319acff91c8a33501c8eae286f337661f5bbd -size 70283536 +oid sha256:25d0f8374e7023760dfedf2f99fb7d56c22f02f0f4b82634e6166515f111fcc2 +size 85709024 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/marlin.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
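The fake-op registrations in the marlin.py hunk above are guarded by hasattr(ops, "gptq_marlin_24_gemm"), so they are skipped on builds whose extension does not expose the Marlin ops, and each fake only reports an output shape and dtype so that torch.compile and torch.library.opcheck can trace the custom ops without launching a kernel. A minimal sketch of the same pattern, assuming torch >= 2.4 and using a made-up demo::row_sum op that is not part of this package:

import torch

# Hypothetical op, for illustration only; the real kernels register their
# fakes against the _quantization namespace via add_op_namespace_prefix.
@torch.library.custom_op("demo::row_sum", mutates_args=())
def row_sum(a: torch.Tensor) -> torch.Tensor:
    return a.sum(dim=-1)

# The fake ("meta") implementation: no real work, just the output spec,
# mirroring the _*_gemm_fake functions above.
@row_sum.register_fake
def _(a: torch.Tensor) -> torch.Tensor:
    return torch.empty(a.shape[:-1], dtype=a.dtype, device=a.device)

out = row_sum(torch.randn(4, 8))  # an eager call still runs the real body
# Under torch.compile or torch.library.opcheck, the fake above is used to
# propagate shapes, which is all the Marlin fakes need to do as well.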
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
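Taken together, the helpers above cover part of what a GPTQ linear layer prepares before calling gptq_marlin_gemm: a capability/group-size check, a workspace, and tile-permuted scales. A hedged sketch of how they compose (the shapes are invented, a CUDA device is assumed, and this snippet is not part of the committed files):

import torch
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import (
    check_marlin_supported,
    marlin_make_workspace,
    marlin_permute_scales,
)

out_features, in_features, group_size = 4096, 4096, 128
wtype = scalar_types.uint4b8  # 4-bit storage with bias 8, i.e. values in [-8, 7]

if check_marlin_supported(wtype, group_size, has_zp=False):
    device = torch.device("cuda")
    # Scratch buffer sized from the output partition and GPTQ_MARLIN_MAX_PARALLEL.
    workspace = marlin_make_workspace(out_features, device)
    # Group scales come in as (k // group_size, n) and are permuted into the
    # layout the Marlin kernel expects.
    scales = torch.rand(in_features // group_size, out_features,
                        dtype=torch.half, device=device)
    marlin_scales = marlin_permute_scales(scales, size_k=in_features,
                                          size_n=out_features,
                                          group_size=group_size)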
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
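The compression/decompression pair defined next is easiest to sanity-check with a round trip. A sketch, assuming a CUDA device is available (inject_24 moves its mask to CUDA) and transposing first, as compress_quantized_24_weight below does, because the 2:4 pattern is injected along the k dimension:

import torch
from quantization.utils.marlin_utils_test_24 import (
    inject_24,
    sparse_semi_structured_from_dense_cutlass,
    sparse_semi_structured_to_dense_cutlass,
)

size_k, size_n = 128, 128
w = torch.randn(size_k, size_n, dtype=torch.half, device="cuda")

# Prune to 2:4 sparsity, then compress in the CUTLASS layout.
w_24, _mask = inject_24(w, size_k, size_n)
sparse, meta = sparse_semi_structured_from_dense_cutlass(w_24.t().contiguous())

# Decompression must reproduce the pruned matrix exactly.
w_back = sparse_semi_structured_to_dense_cutlass(sparse, meta).t()
assert torch.equal(w_back, w_24)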
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
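As a worked instance of the encoding table above, take the quadruple [False, True, False, True]: the kept positions are 1 and 3, so in the expressions that follow expr1 is True (bit0 = 1, bit1 = 0, giving idxs0 = 0b01) and bit2 = bit3 = 1 (giving idxs1 = 0b11); the stored nibble idxs0 | (idxs1 << 2) is therefore 0b1101, matching the listing.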
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name]
+FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"],
+}
+
+
+def pack_quantized_values_into_int32(
+    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
+):
+    # move dim to pack to the end
+    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
+    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+    w_q_perm = w_q.permute(perm)
+
+    pack_factor = 32 // wtype.size_bits
+    mask = (1 << wtype.size_bits) - 1
+
+    new_shape_perm = list(w_q_perm.shape)
+    assert w_q_perm.shape[-1] % pack_factor == 0
+    new_shape_perm[-1] //= pack_factor
+
+    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
+    for i in range(pack_factor):
+        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
+
+    return res.permute(inv_perm)
+
+
+def unpack_quantized_values_into_int32(
+    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
+):
+    # move dim to pack to the end
+    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
+    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+    w_q_perm = w_q.permute(perm)
+
+    pack_factor = 32 // wtype.size_bits
+    mask = (1 << wtype.size_bits) - 1
+
+    new_shape_perm = list(w_q_perm.shape)
+    new_shape_perm[-1] *= pack_factor
+
+    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
+    for i in range(pack_factor):
+        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
+
+    return res.permute(inv_perm)
+
+
+def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index e7ade1ed65096fe68f31ce2d2339182da5467c89..573613091ecbe6d98d33178c07c5b1c14d9f4d07 100755 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cfb729ac5ab820b7da442dab20258a53c83458f5299dfdb1f49fb5008347a917 -size 86060232 +oid sha256:e153ae86e155d6fc792f08b4986e899467299abe50e3d39519de4b7e7198a5fa +size 105258480 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/marlin.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.:
+    `stored_value = value + bias`,
+    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this is an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number of bits representing an integer (excluding the sign bit)
+    if this is an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0). For example, if we store the
+    type as an unsigned integer with a bias of 128, then the value 0 will be
+    stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: if infs are supported, use `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type; returns a NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
+                or self.nan_repr == NanRepr.NONE):
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
+        # e is the exponent bits), there is some precedent for non-standard
+        # biases, example `float8_e4m3b11fnuz` here:
+        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+        # complication we are just assuming the standard exponent bias until
+        # there is a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = (max_exponent - exponent_bias +
+                               exponent_bias_double)
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and \
+            not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float_em[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int[b]`
+          - if bias is not present it means it's zero
+        """
+        if self.is_floating_point():
+            ret = "float" + str(self.size_bits) + "_e" + str(
+                self.exponent) + "m" + str(self.mantissa)
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert (mantissa > 0 and exponent > 0)
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
+               nan_repr: NanRepr) -> 'ScalarType':
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
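As a rough intuition for what the conversion below produces (a simplified sketch only; the real function additionally encodes the kept positions into 2-bit metadata and reorders it into the CUTLASS ColumnMajorInterleaved layout): 2:4 semi-structured sparsity keeps at most two values out of every group of four and records which positions survived.

import torch

# Toy illustration, not part of this diff: keep the two largest-magnitude
# entries per group of four and remember their positions (the information
# that the meta tensor encodes).
dense = torch.tensor([[0.0, 3.0, 0.0, -1.5, 2.0, 0.0, 0.0, 4.0]])
groups = dense.view(-1, 4)
_, kept_idx = groups.abs().topk(2, dim=-1)
kept_idx, _ = kept_idx.sort(dim=-1)              # preserve in-group ordering
compressed = torch.gather(groups, -1, kept_idx)  # half the original width
# compressed: [[3.0, -1.5], [2.0, 4.0]]; kept_idx: [[1, 3], [0, 3]]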
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
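    # Worked example of the scheme above: the quadruple
    # [True, False, False, True] encodes as 0b1100, since the lower two bits
    # (0b00) index the first kept position (0) and the upper two bits (0b11)
    # index the other kept position (3).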
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
+# fused_name: List[shard_name] +FUSED_LAYER_NAME_MAPPING = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], +} + + +def pack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32( + w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 +): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + if proj_name in FUSED_LAYER_NAME_MAPPING: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. All shards of fused layers " + "to have the same precision." 
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
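Before the QQQ-specific routine below, a minimal sketch of the plain symmetric path implemented by quantize_weights above (illustrative only; it ignores grouping, zero points, and the divide-by-zero guard, and the signed 4-bit range is an assumed example):

import torch

# Symmetric per-channel quantization to a signed 4-bit range, mirroring the
# scale / round / clamp recipe used above (sketch, not the packaged code).
w = torch.randn(128, 64)
max_q_val, min_q_val = 7, -8
w_s = torch.max((w.amax(dim=0, keepdim=True) / max_q_val).abs(),
                (w.amin(dim=0, keepdim=True) / min_q_val).abs())
w_q = torch.clamp(torch.round(w / w_s), min_q_val, max_q_val)
w_ref = w_q * w_s                              # dequantized reference
assert (w - w_ref).abs().max() <= w_s.max()    # error bounded by one scale step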
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS + ), f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = 
q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, + size_n // pack_factor, + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor + ) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py index fcf21efe93dc773645fad1406833d98282594b52..4f98b4e5768653e1e6ad118c5a219c39b6f4b7d4 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py @@ -1,150 +1,30 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - ops = torch.ops._quantization - except ImportError: - raise e - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - -def cutlass_scaled_mm(a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) - assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) - assert bias is None or bias.shape[0] == b.shape[ - 1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - #if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert (input.ndim == 2) - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - #out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), - device=input.device, - dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant( - output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert (scale.numel() == 1 or num_token_padding is None) - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is - None), "azp must only be provided for asymmetric quantization." 
- ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty((input.numel() // input.shape[-1], 1), - device=input.device, - dtype=torch.float32) - input_azp = None if symmetric else torch.empty_like(input_scales, - dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, - input_azp) - return output, input_scales, input_azp - -# fp8 marlin -def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: int, size_n: int, - size_k: int) -> torch.Tensor: - return ops.fp8_marlin_gemm(a, b_q_weight, b_scales, workspace, - num_bits, size_m, size_n, size_k) +from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant +from .cutlass import ( + cutlass_scaled_mm_supports_fp8, + cutlass_scaled_mm, + cutlass_scaled_mm_azp, +) +from .marlin import ( + awq_marlin_repack, + fp8_marlin_gemm, + gptq_marlin_gemm, + gptq_marlin_repack, + gptq_marlin_24_gemm, + marlin_qqq_gemm, + marlin_gemm, +) + +__all__ = [ + "awq_marlin_repack", + "cutlass_scaled_mm", + "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_fp8", + "fp8_marlin_gemm", + "gptq_marlin_24_gemm", + "gptq_marlin_gemm", + "gptq_marlin_repack", + "marlin_gemm", + "marlin_qqq_gemm", + "scaled_fp8_quant", + "scaled_int8_quant", +] diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py index 29f0ec67c5ae23385ac0d85e16ac8997f73ef96d..3c51cae9fb77fa474a11555bd922681ee490a427 100644 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py @@ -1,3 +1,9 @@ import torch from . import _quantization_0_0_1 ops = torch.ops._quantization_0_0_1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_quantization_0_0_1::{op_name}" diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so b/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so index 029a859319c687f8ead8ceea7fb50aeed786157c..4812f245bb85d6fd30b6998ca6d72d6d98532b1f 100755 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0_0_1.abi3.so @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:314801e406b4ff5eac0b4b59ef617ace82079634dbfc476ba6ce73daadc47dbe -size 89571096 +oid sha256:e8016e26ae454e2c820104d11b44bd30adaffc112b70043daeb58bfb2fab9f1c +size 109243600 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ba30bac87979a307fc5061a46f5d2cbf0efbf9 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py @@ -0,0 +1,110 @@ +from typing import Optional, Tuple + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. 
+ try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +# fp8 +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + num_token_padding: Optional[int] = None, + scale_ub: Optional[torch.Tensor] = None, + use_per_token_if_dynamic: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to FP8 and return quantized tensor and scale. + + This function supports both static and dynamic quantization: If you + provide the scale, it will use static scaling and if you omit it, + the scale will be determined dynamically. The function also allows + optional padding of the output tensors for downstream kernels that + will benefit from padding. + + Args: + input: The input tensor to be quantized to FP8 + scale: Optional scaling factor for the FP8 quantization + scale_ub: Optional upper bound for scaling factor in dynamic + per token case + num_token_padding: If specified, pad the first dimension + of the output to at least this value. + use_per_token_if_dynamic: Whether to do per_tensor or per_token + in the dynamic quantization case. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + scaling factor. + """ + # This code assumes batch_dim and num_tokens are flattened + assert input.ndim == 2 + shape: Union[Tuple[int, int], torch.Size] = input.shape + # For rocm, the output fp8 dtype is torch.float_e3m3fnuz + # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ + # if current_platform.is_rocm() else torch.float8_e4m3fn + out_dtype = torch.float8_e4m3fn + if num_token_padding: + shape = (max(num_token_padding, input.shape[0]), shape[1]) + output = torch.empty(shape, device=input.device, dtype=out_dtype) + + if scale is None: + if use_per_token_if_dynamic: + scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + else: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + # num_token_padding not implemented for this case + assert scale.numel() == 1 or num_token_padding is None + ops.static_scaled_fp8_quant(output, input, scale) + + return output, scale + + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) + if scale is not None: + # static-per-tensor quantization. + assert symmetric == ( + azp is None + ), "azp must only be provided for asymmetric quantization." + ops.static_scaled_int8_quant(output, input, scale, azp) + return output, scale, azp + + # dynamic-per-token quantization. 
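A usage sketch of the two code paths of scaled_int8_quant defined in this file (illustrative only; it assumes a CUDA device and that the package is importable as quantization, as done elsewhere in these build directories):

import torch
from quantization import scaled_int8_quant

x = torch.randn(16, 4096, dtype=torch.float16, device="cuda")

# Dynamic per-token: no scale supplied, one scale per row is computed and returned.
q_dyn, dyn_scales, dyn_azp = scaled_int8_quant(x, symmetric=True)  # dyn_azp is None

# Static per-tensor: a precomputed scale is supplied and passed through unchanged.
scale = torch.tensor([0.02], dtype=torch.float32, device="cuda")
q_static, static_scale, _ = scaled_int8_quant(x, scale=scale)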
+ input_scales = torch.empty( + (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 + ) + input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + return output, input_scales, input_azp diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..c378b846d0c59de183a321fcad4b403c47b3d750 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/cutlass.py @@ -0,0 +1,75 @@ +from typing import Optional + +import torch + +try: + from ._ops import ops +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + except ImportError: + raise e + + +def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) + + +def cutlass_scaled_mm( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype + + m = a.shape[0] + n = b.shape[1] + + # if current_platform.is_rocm(): + # triton_scaled_mm_module = importlib.import_module( + # "vllm.model_executor.layers.quantization.compressed_tensors." + # "triton_scaled_mm") + # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm + # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) + + return out + + +def cutlass_scaled_mm_azp( + a: torch.Tensor, + b: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out_dtype: torch.dtype, + azp_adj: torch.Tensor, + azp: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + :param azp_adj: In the per-tensor case, this should include the azp. + Always per-channel. + :param azp: Only set in the per-token case. Per-token if set. 
+ """ + assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 + assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 + assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype + assert azp is None or azp.numel() == a.shape[0] + + m = a.shape[0] + n = b.shape[1] + out = torch.empty((m, n), dtype=out_dtype, device=a.device) + + ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) + return out diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/marlin.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/marlin.py new file mode 100644 index 0000000000000000000000000000000000000000..44d5d28a2fb67af955c017af3cf1403feeecbd32 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/marlin.py @@ -0,0 +1,208 @@ +from typing import TYPE_CHECKING + +import torch + +# neuron has torch version that doesn't even have impl_abstract +if TYPE_CHECKING: + def register_fake(fn): + return lambda name: fn +else: + try: + from torch.library import register_fake + except ImportError: + from torch.library import impl_abstract as register_fake + +try: + from ._ops import ops, add_op_namespace_prefix +except ImportError as e: + # Fallback for local development. + try: + import _quantization + + ops = torch.ops._quantization + + def add_op_namespace_prefix(op_name: str): + return f"_quantization::{op_name}" + except ImportError: + raise e + + +from .scalar_type import ScalarType + + +# fp8 marlin +def fp8_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + num_bits: int, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.fp8_marlin_gemm( + a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k + ) + + +# gptq_marlin +def gptq_marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + return ops.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + has_zp, + use_fp32_reduce, + is_zp_float, + ) + + +# gptq_marlin +def gptq_marlin_repack( + b_q_weight: torch.Tensor, + perm: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) + + +# gptq_marlin +def awq_marlin_repack( + b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) + + +# marlin +def marlin_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_gemm( + a, b_q_weight, b_scales, workspace, size_m, size_n, size_k + ) + + +# marlin_24 +def gptq_marlin_24_gemm( + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_meta: torch.Tensor, + b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.gptq_marlin_24_gemm( + a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k + ) + + +# qqq ops +def marlin_qqq_gemm( + a: 
torch.Tensor, + b_q_weight: torch.Tensor, + s_tok: torch.Tensor, + s_ch: torch.Tensor, + s_group: torch.Tensor, + workspace: torch.Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> torch.Tensor: + return ops.marlin_qqq_gemm( + a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +# Fake ops + +if hasattr(ops, "gptq_marlin_24_gemm"): + @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) + def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + num_bits: int, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) + + @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) + def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_meta: torch.Tensor, b_scales: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) + def _gptq_marlin_gemm_fake(a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool, + has_zp: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) + + @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) + def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + s_tok: torch.Tensor, s_ch: torch.Tensor, + s_group: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) + + @register_fake(add_op_namespace_prefix("marlin_gemm")) + def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, workspace: torch.Tensor, + size_m: torch.SymInt, size_n: torch.SymInt, + size_k: torch.SymInt) -> torch.Tensor: + return torch.empty((size_m, size_n), + dtype=torch.float16, + device=a.device) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/scalar_type.py new file mode 100644 index 0000000000000000000000000000000000000000..9d711b0debcd8aaa343818edc9d6bbca20587d0a --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/scalar_type.py @@ -0,0 +1,330 @@ +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from typing import Optional, Union + + +# Mirrors enum in `core/scalar_type.hpp` +class NanRepr(Enum): + NONE = 0 # nans are not supported + IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s + EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s + + +# This ScalarType class is a parallel implementation of the C++ ScalarType +# class found in csrc/core/scalar_type.hpp. These two classes should be kept +# in sync until the inductor fully supports custom C++ classes. 
+@dataclass(frozen=True) +class ScalarType: + """ + ScalarType can represent a wide range of floating point and integer + types, in particular it can be used to represent sub-byte data types + (something that torch.dtype currently does not support). It is also + capable of representing types with a bias, i.e.: + `stored_value = value + bias`, + this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias + of 8). The implementation for this class can be found in + csrc/core/scalar_type.hpp, these type signatures should be kept in sync + with that file. + """ + + exponent: int + """ + Number of bits in the exponent if this is a floating point type + (zero if this an integer type) + """ + + mantissa: int + """ + Number of bits in the mantissa if this is a floating point type, + or the number bits representing an integer excluding the sign bit if + this an integer type. + """ + + signed: bool + "If the type is signed (i.e. has a sign bit)" + + bias: int + """ + bias used to encode the values in this scalar type + (value = stored_value - bias, default 0) for example if we store the + type as an unsigned integer with a bias of 128 then the value 0 will be + stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. + """ + + _finite_values_only: bool = False + """ + Private: if infs are supported, used `has_infs()` instead. + """ + + nan_repr: NanRepr = NanRepr.IEEE_754 + """ + How NaNs are represent in this scalar type, returns NanRepr value. + (not applicable for integer types) + """ + + def _floating_point_max_int(self) -> int: + assert ( + self.mantissa <= 52 and self.exponent <= 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + + max_mantissa = (1 << self.mantissa) - 1 + if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: + max_mantissa = max_mantissa - 1 + + max_exponent = (1 << self.exponent) - 2 + if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN + or self.nan_repr == NanRepr.NONE): + assert ( + self.exponent < 11 + ), f"Cannot represent max/min as a double for type {self.__str__()}" + max_exponent = max_exponent + 1 + + # adjust the exponent to match that of a double + # for now we assume the exponent bias is the standard 2^(e-1) -1, (where + # e is the exponent bits), there is some precedent for non-standard + # biases, example `float8_e4m3b11fnuz` here: + # https://github.com/jax-ml/ml_dtypes but to avoid premature over + # complication we are just assuming the standard exponent bias until + # there is a need to support non-standard biases + exponent_bias = (1 << (self.exponent - 1)) - 1 + exponent_bias_double = (1 << 10) - 1 # double e = 11 + + max_exponent_double = (max_exponent - exponent_bias + + exponent_bias_double) + + # shift the mantissa and exponent into the proper positions for an + # IEEE double and bitwise-or them together. 
+ return (max_mantissa << + (52 - self.mantissa)) | (max_exponent_double << 52) + + def _floating_point_max(self) -> float: + double_raw = self._floating_point_max_int() + return struct.unpack('!d', struct.pack('!Q', double_raw))[0] + + def _raw_max(self) -> Union[int, float]: + if self.is_floating_point(): + return self._floating_point_max() + else: + assert (self.size_bits < 64 or self.size_bits == 64 + and self.is_signed()), "Cannot represent max as an int" + return (1 << self.mantissa) - 1 + + def _raw_min(self) -> Union[int, float]: + if self.is_floating_point(): + assert self.is_signed( + ), "We currently assume all floating point types are signed" + sign_bit_double = 1 << 63 + + max_raw = self._floating_point_max_int() + min_raw = max_raw | sign_bit_double + return struct.unpack('!d', struct.pack('!Q', min_raw))[0] + else: + assert (not self.is_signed() or + self.size_bits <= 64), "Cannot represent min as a int64_t" + + if self.is_signed(): + return -(1 << (self.size_bits - 1)) + else: + return 0 + + @functools.cached_property + def id(self) -> int: + """ + Convert the ScalarType to an int which can be passed to pytorch custom + ops. This layout of the int must be kept in sync with the C++ + ScalarType's from_id method. + """ + val = 0 + offset = 0 + + def or_and_advance(member, bit_width): + nonlocal val + nonlocal offset + bit_mask = (1 << bit_width) - 1 + val = val | (int(member) & bit_mask) << offset + offset = offset + bit_width + + or_and_advance(self.exponent, 8) + or_and_advance(self.mantissa, 8) + or_and_advance(self.signed, 1) + or_and_advance(self.bias, 32) + or_and_advance(self._finite_values_only, 1) + or_and_advance(self.nan_repr.value, 8) + + assert offset <= 64, \ + f"ScalarType fields too big {offset} to fit into an int64" + + return val + + @property + def size_bits(self) -> int: + return self.exponent + self.mantissa + int(self.signed) + + def min(self) -> Union[int, float]: + """ + Min representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_min() - self.bias + + def max(self) -> Union[int, float]: + """ + Max representable value for this scalar type. + (accounting for bias if there is one) + """ + return self._raw_max() - self.bias + + def is_signed(self) -> bool: + """ + If the type is signed (i.e. 
has a sign bit), same as `signed` + added for consistency with: + https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html + """ + return self.signed + + def is_floating_point(self) -> bool: + "If the type is a floating point type" + return self.exponent != 0 + + def is_integer(self) -> bool: + "If the type is an integer type" + return self.exponent == 0 + + def has_bias(self) -> bool: + "If the type has a non-zero bias" + return self.bias != 0 + + def has_infs(self) -> bool: + "If the type is floating point and supports infinity" + return not self._finite_values_only + + def has_nans(self) -> bool: + return self.nan_repr != NanRepr.NONE.value + + def is_ieee_754(self) -> bool: + """ + If the type is a floating point type that follows IEEE 754 + conventions + """ + return self.nan_repr == NanRepr.IEEE_754.value and \ + not self._finite_values_only + + def __str__(self) -> str: + """ + naming generally follows: https://github.com/jax-ml/ml_dtypes + for floating point types (leading f) the scheme is: + `float_em[flags]` + flags: + - no-flags: means it follows IEEE 754 conventions + - f: means finite values only (no infinities) + - n: means nans are supported (non-standard encoding) + for integer types the scheme is: + `[u]int[b]` + - if bias is not present it means its zero + """ + if self.is_floating_point(): + ret = "float" + str(self.size_bits) + "_e" + str( + self.exponent) + "m" + str(self.mantissa) + + if not self.is_ieee_754(): + if self._finite_values_only: + ret = ret + "f" + if self.nan_repr != NanRepr.NONE: + ret = ret + "n" + + return ret + else: + ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) + if self.has_bias(): + ret = ret + "b" + str(self.bias) + return ret + + def __repr__(self) -> str: + return "ScalarType." + self.__str__() + + # __len__ needs to be defined (and has to throw TypeError) for pytorch's + # opcheck to work. + def __len__(self) -> int: + raise TypeError + + # + # Convenience Constructors + # + + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + "Create a signed integer scalar type (size_bits includes sign-bit)." + ret = cls(0, size_bits - 1, True, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + """Create a unsigned integer scalar type.""" + ret = cls(0, size_bits, False, bias if bias else 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': + """ + Create a standard floating point type + (i.e. follows IEEE 754 conventions). + """ + assert (mantissa > 0 and exponent > 0) + ret = cls(exponent, mantissa, True, 0) + ret.id # noqa B018: make sure the id is cached + return ret + + @classmethod + def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, + nan_repr: NanRepr) -> 'ScalarType': + """ + Create a non-standard floating point type + (i.e. does not follow IEEE 754 conventions). 
+ """ + assert (mantissa > 0 and exponent > 0) + assert (nan_repr != NanRepr.IEEE_754), ( + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions") + ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) + ret.id # noqa B018: make sure the id is cached + return ret + + +# naming generally follows: https://github.com/jax-ml/ml_dtypes +# for floating point types (leading f) the scheme is: +# `float_em[flags]` +# flags: +# - no-flags: means it follows IEEE 754 conventions +# - f: means finite values only (no infinities) +# - n: means nans are supported (non-standard encoding) +# for integer types the scheme is: +# `[u]int[b]` +# - if bias is not present it means its zero + + +class scalar_types: + int4 = ScalarType.int_(4, None) + uint4 = ScalarType.uint(4, None) + int8 = ScalarType.int_(8, None) + uint8 = ScalarType.uint(8, None) + float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) + float8_e5m2 = ScalarType.float_IEEE754(5, 2) + float16_e8m7 = ScalarType.float_IEEE754(8, 7) + float16_e5m10 = ScalarType.float_IEEE754(5, 10) + + # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main + float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + + # "gptq" types + uint2b2 = ScalarType.uint(2, 2) + uint3b4 = ScalarType.uint(3, 4) + uint4b8 = ScalarType.uint(4, 8) + uint8b128 = ScalarType.uint(8, 128) + + # colloquial names + bfloat16 = float16_e8m7 + float16 = float16_e5m10 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c94c38858a5cd6f02eb134d1a94b99a2b15566 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py @@ -0,0 +1,391 @@ +from typing import List, Optional, Tuple + +import numpy +import torch + +import quantization as ops +from quantization.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_24_TILE = 16 +GPTQ_MARLIN_24_MIN_THREAD_N = 128 +GPTQ_MARLIN_24_MIN_THREAD_K = 128 +GPTQ_MARLIN_24_MAX_PARALLEL = 64 + +GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_QQQ_TILE = 16 +MARLIN_QQQ_MIN_THREAD_N = 64 +MARLIN_QQQ_MIN_THREAD_K = 128 +MARLIN_QQQ_MAX_PARALLEL = 16 + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] +MARLIN_QQQ_SUPPORTED_SYM = [True] + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types( + has_zp: bool, device_capability: Optional[int] = None +): + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None, +) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = torch.cuda.get_device_capability() + device_capability = capability_tuple[0] * 10 + capability_tuple[1] + + supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + + if quant_type not in supported_types: + return ( + False, + f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).", + ) + if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return ( + False, + f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.", + ) + + return True, None + + +def check_marlin_supported( + quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None, +) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) + return cond + + +def verify_marlin_supported( + quant_type: ScalarType, group_size: int, has_zp: bool = False +) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." + ) + + if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq." 
+ ) + + +def check_marlin_supports_shape( + output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, + group_size: int, +) -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape( + output_size_per_partition, input_size_per_partition, input_size, group_size + ) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace( + output_size_per_partition: int, device: torch.device +) -> torch.Tensor: + max_workspace_size = ( + output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N + ) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros( + max_workspace_size, dtype=torch.int, device=device, requires_grad=False + ) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks( + act_order: bool, group_size: int, is_row_parallel: bool +) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter( + torch.empty(0, dtype=torch.int, device=device), requires_grad=False + ) + + +def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points( + zp: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp = pack_cols(zp, num_bits, size_k, size_n) + + return zp + 
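A minimal usage sketch of the helpers above, with hypothetical layer sizes and an assumed CUDA device:

import torch

size_k, size_n, group_size = 4096, 4096, 128
device = torch.device("cuda")

# True on devices with compute capability >= 8.0 for the GPTQ-style uint4b8 type.
supported = check_marlin_supported(scalar_types.uint4b8, group_size, device_capability=90)

# Workspace sized as (size_n // min_thread_n) * max_parallel int32 zeros.
workspace = marlin_make_workspace(size_n, device)

# Group scales come in as (size_k // group_size, size_n); the helper only re-orders
# them into the layout the Marlin kernel expects, so the shape is unchanged.
s = torch.rand(size_k // group_size, size_n, dtype=torch.float16, device=device)
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size)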
+ +def awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points( + q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int +): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, +) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition,) + + output = ops.gptq_marlin_gemm( + reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 
0000000000000000000000000000000000000000..b269fa6a4cee316e8299ecc86c3e7594b336b499 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,100 @@ +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + + +def is_fp8_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n,) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin( + layer: torch.nn.Module, strategy: str = "tensor" +) -> None: + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=pack_fp8_to_int32(layer.weight), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8, + ) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales( + s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 + ) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = ( + byte_tensor[:, 0].to(torch.int32) + | (byte_tensor[:, 1].to(torch.int32) << 8) + | (byte_tensor[:, 2].to(torch.int32) << 16) + | (byte_tensor[:, 3].to(torch.int32) << 24) + ) + + return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py @@ -0,0 +1,162 @@ +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points +from .quant_utils import ( + get_pack_factor, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert ( + out_features % min_thread_n == 0 + ), "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n + ) + + max_workspace_size = (out_features // min_thread_n) * max_parallel + + self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm + ) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if 
group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..927fa9016ba25f381c09d768db0c468066193a76 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,473 @@ +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from quantization.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = ( + dst_rows // group_x * group_x + + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4 + ) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
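For orientation, a small round-trip sketch of the two conversion helpers that follow, together with `inject_24` from later in this file; the shapes are illustrative and chosen to satisfy the divisibility checks (m % 32 == 0 and k % 16 == 0 for fp16 inputs), and a CUDA device is assumed:

import torch

# Prune a dense fp16 matrix to 2:4 sparsity along its k (row) dimension.
dense = torch.randn(64, 128, dtype=torch.half, device="cuda")
dense_24, _mask = inject_24(dense, size_k=64, size_n=128)

# Transpose so the 2:4 groups lie along the last dimension, then compress into
# the CUTLASS layout: packed values plus reordered metadata.
sparse, meta = sparse_semi_structured_from_dense_cutlass(dense_24.t().contiguous())

# Reconstruct the dense matrix; `recovered` should match dense_24.t() exactly,
# since the pruned positions are exact zeros.
recovered = sparse_semi_structured_to_dense_cutlass(sparse, meta)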
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError("Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16" + ) + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32" + ) + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
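    # Worked example: a quadruple [0.5, 0.0, -1.2, 0.0] has its two nonzeros at
    # indices 0 and 2, so idxs0 = 0 (low bits) and idxs1 = 2 (high bits), giving
    # the 4-bit code 0b1000 from the table above; the compressed matrix keeps
    # (0.5, -1.2) and the metadata records the two indices.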
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1) + ) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( + m, k // 2 + ) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + ) + elif quadbits_per_meta_elem == 8: + meta = ( + meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28) + ) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols,) + ) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix" + ) + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device + ) + meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4 + ).view(-1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_( + 0, dense_offsets, sparse.view(torch.half).view(-1) + ) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" + ) + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j : j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24( + s: torch.Tensor, size_k: int, size_n: int, group_size: int +) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False + ) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights( + q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm + ) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..cb58eb945836393c58c53f5c6d702d53861c33f9 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,125 @@ +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d97e03913fa5980e0be73b160088c8e4f5f49a52 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py @@ -0,0 +1,470 @@ +"""This file is used for /tests and /benchmarks""" + +from typing import List, Optional + +import numpy +import torch + +from quantization.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] + +# Note: this is a hack. We should update each model to register the +# stacked params and get it from there instead in a future PR. 
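A small sketch of how the mapping below is consumed by `is_layer_skipped` (defined further down in this file); the layer names are hypothetical:

ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
# The fused "qkv_proj" counts as skipped only because all of its shards are ignored.
is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored)   # -> True
is_layer_skipped("model.layers.0.mlp.down_proj", ignored)        # -> False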
+# fused_name: List[shard_name]
+FUSED_LAYER_NAME_MAPPING = {
+    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+    "gate_up_proj": ["gate_proj", "up_proj"],
+}
+
+
+def pack_quantized_values_into_int32(
+    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
+):
+    # move dim to pack to the end
+    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
+    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+    w_q_perm = w_q.permute(perm)
+
+    pack_factor = 32 // wtype.size_bits
+    mask = (1 << wtype.size_bits) - 1
+
+    new_shape_perm = list(w_q_perm.shape)
+    assert w_q_perm.shape[-1] % pack_factor == 0
+    new_shape_perm[-1] //= pack_factor
+
+    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
+    for i in range(pack_factor):
+        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
+
+    return res.permute(inv_perm)
+
+
+def unpack_quantized_values_into_int32(
+    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
+):
+    # move dim to pack to the end
+    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
+    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
+    w_q_perm = w_q.permute(perm)
+
+    pack_factor = 32 // wtype.size_bits
+    mask = (1 << wtype.size_bits) - 1
+
+    new_shape_perm = list(w_q_perm.shape)
+    new_shape_perm[-1] *= pack_factor
+
+    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
+    for i in range(pack_factor):
+        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
+
+    return res.permute(inv_perm)
+
+
+def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+    if proj_name in FUSED_LAYER_NAME_MAPPING:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+ ) + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows( + q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None, +): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size,), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = ( + torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() + ) + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), + ) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied, for this case computing the reference in similar way + # allows us to use tighter error tolerances in our unit tests. 
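+    # In other words, the reference below is either
+    #     w_ref = (w_q - maybe_w_zp) * w_s           (default)
+    # or
+    #     w_ref = w_q * w_s - maybe_w_zp * w_s       (zero-points after scales),
+    # which are algebraically equal but round differently in low precision.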
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None, +): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert ( + quant_type in SUPPORTED_GPTQ_QUANT_TYPES + ), f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k + ) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
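+# Summary of the two branches below: for group_size < size_k the weights get
+# symmetric 4-bit group scales (s_group, fp16, fused with the per-channel
+# scale) plus fp32 per-channel int8 scales (s_channel); for per-channel-only
+# quantization (group_size == size_k) s_group is returned empty and only
+# s_channel is used.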
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int):
+    orig_device = w.device
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+    assert (
+        num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS
+    ), f"Unsupported num_bits = {num_bits}"
+    assert group_size in SUPPORTED_GROUP_SIZES + [
+        size_k
+    ], f"Unsupported groupsize = {group_size}"
+
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    if group_size < size_k:
+        # Reshape to [groupsize, -1]
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+        max_q_val = 2**num_bits - 1
+        half_q_val = (max_q_val + 1) // 2
+
+        # Compute scale for each group
+        s_group = torch.max(torch.abs(w), 0, keepdim=True)[0]
+        s_group *= 2 / max_q_val  # 2 => symmetric
+
+        # Quantize
+        q_w = torch.round(w / s_group).int()
+        q_w += half_q_val
+        q_w = torch.clamp(q_w, 0, max_q_val)
+        # Compute ref (dequantized)
+        w_ref = (q_w - half_q_val).half() * s_group
+
+        # Restore original shapes
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        q_w = reshape_w(q_w)
+        w_ref = reshape_w(w_ref)
+
+        # Compute int8 quantization scale for each channel
+        s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0]
+        s_channel /= 127.0
+        t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
+        w_ref = t_int8.half() * s_channel
+        s_channel = s_channel.reshape(1, -1).to(dtype=torch.float)
+
+        # Fuse scales
+        s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to(
+            dtype=torch.half
+        )
+    else:
+        max_q_val = 2 ** (num_bits - 1) - 1
+
+        # Compute scale for each channel
+        s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0]
+        s_channel /= max_q_val
+
+        # Quantize
+        q_w = torch.round(w / s_channel).int()
+        q_w = torch.clamp(q_w, -max_q_val, max_q_val)
+        # Compute ref (dequantized)
+        w_ref = q_w.half() * s_channel
+
+        s_group = torch.tensor([], dtype=torch.half)
+        # Divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
+        s_channel /= 2 ** (8 - num_bits)
+        s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float)
+
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        s_group.to(device=orig_device),
+        s_channel.to(device=orig_device),
+    )
+
+
+def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
+    orig_device = q_w.device
+
+    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
+
+    g_idx = g_idx[sort_indices].contiguous()
+    q_w = q_w[sort_indices, :].contiguous()
+
+    return (
+        q_w.to(device=orig_device),
+        g_idx.to(device=orig_device),
+        sort_indices.to(device=orig_device),
+    )
+
+
+def pack_rows(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_k % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[i::pack_factor, :] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    return q_res
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {}, size_k = {}, size_n = {}, pack_factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def gptq_pack(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    return pack_rows(q_w, num_bits, size_k, size_n)
+
+
+def awq_pack(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    # Interleave column dim (for the dequantize code) and pack it to int32
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel()
+    q_w = q_w.reshape((-1, size_n)).contiguous()
+
+    return pack_cols(q_w, num_bits, size_k, size_n)
diff --git a/ext-torch/utils/__init__.py b/ext-torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391