diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 960b41a51a36fa28484c3237780e1e9fe97a1428..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:b01b2ee690ad54303926af36debf43382f596fee6396822365b8ea88ae284eec -size 63485168 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. 
- - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. - """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/marlin.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - 
b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. 
- """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. 
- (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu118-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
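# In exact arithmetic the two branches below are identical, since
#     (w_q - zp) * s == w_q * s - zp * s,
# but the subtractions round differently once the result is cast back to
# fp16/bf16, so the reference dequantization follows the same order as the
# kernel under test to keep the comparison within tight tolerances.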
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
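# For group_size < size_k, the weights get a per-group fp16 scale (s_group,
# later expressed relative to the per-channel int8 scale) plus a per-channel
# fp32 scale (s_channel); for per-channel quantization (group_size == size_k)
# only s_channel is used and s_group is returned empty.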
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 956896653f011b690aa5be26be99f1f8b47f0d63..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:8531abccfae1c201e83ad1279bdb092dd77a89e4dc7bc166bbe0625e2bbc6665 -size 64993040 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. - """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/marlin.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. 
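# (e.g. running from a source checkout where the extension is built as a
# top-level `_quantization` module instead of being installed under this
# package with a hashed name)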
- try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - 
b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. 
- """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. 
- (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
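# A minimal usage sketch for the helper above, with hypothetical shapes and
# values (the `_example_marlin_zero_points` name is illustrative only):
# zero-points are supplied per (group, output column) and come back packed
# into int32 columns, size_n // (32 // num_bits) wide.
def _example_marlin_zero_points(num_bits: int = 4) -> torch.Tensor:
    size_k, size_n, group_size = 128, 256, 64
    num_groups = size_k // group_size
    # 8 is the midpoint of the 4-bit range, i.e. a symmetric zero-point.
    zp = torch.full((num_groups, size_n), 8, dtype=torch.int32)
    # Returns an int32 tensor of shape (num_groups, size_n // (32 // num_bits)).
    return marlin_zero_points(zp, num_groups, size_n, num_bits)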
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
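# Worked example of the encoding table above: for the quadruple
# [True, False, True, False] the kept positions are 0 and 2, so the low
# two bits store index 0 (0b00) and the high two bits store index 2
# (0b10), giving the encoding 0b1000. The boolean expressions below
# compute exactly these two index fields (idxs0 = 0, idxs1 = 2 here) for
# every quadruple at once, and meta_4 = idxs0 | (idxs1 << 2) assembles
# the 4-bit code.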
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu121-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
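# In the per-group branch below, weights are quantized to unsigned 4-bit
# values in [0, max_q_val] around a symmetric group scale; that group scale
# is then re-expressed relative to an int8 per-channel scale
# (s_group / s_channel), so dequantization is roughly
# (q_w - half_q_val) * s_group * s_channel. In the per-channel branch a
# single signed scale per output channel is used, pre-divided by
# 2 ** (8 - num_bits) to offset the right shift applied when unpacking.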
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 05e4ac5b8c9a82e92b111cf551175b0ec0394675..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:290d7f6a8c742481b6655732a3c2d61567fb7bd24c69e23a98dd1eff94895db6 -size 67517912 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. - """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/marlin.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - 
b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. 
- """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. 
- (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
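marlin_zero_points above ends by handing the permuted zero-points to pack_cols. As a rough standalone illustration of that packing scheme (numpy only, toy shapes; not the library call itself): with num_bits = 4, eight 4-bit values share one 32-bit word, value i sitting at bit offset 4*i.

# Minimal sketch of the column-packing idea used by pack_cols/unpack_cols.
import numpy as np

num_bits = 4
pack_factor = 32 // num_bits                 # 8 nibbles per int32 word
q = (np.arange(16, dtype=np.uint32) % (1 << num_bits)).reshape(2, 8)

packed = np.zeros((2, 8 // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    packed |= q[:, i::pack_factor] << (num_bits * i)

unpacked = np.zeros_like(q)
mask = (1 << num_bits) - 1
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = (packed >> (num_bits * i)) & mask

assert np.array_equal(unpacked, q)           # lossless round trip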
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
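The offsets computed in the next few lines follow CUTLASS's ColumnMajorInterleaved<2> layout: metadata columns are stored in pairs, column-major within each pair. A tiny self-contained sketch of just that addressing (toy 4x4 shape; the row swizzle computed above is left out):

# offset(row, col) for an m-row meta matrix under ColumnMajorInterleaved<2>,
# matching cols_maj * m * interleave + dst_rows * interleave + cols_min below.
m, ncols, interleave = 4, 4, 2

def offset(row, col):
    return (col // interleave) * m * interleave + row * interleave + col % interleave

# Walking memory in offset order visits two columns at a time, row by row.
order = sorted(((r, c) for r in range(m) for c in range(ncols)),
               key=lambda rc: offset(*rc))
assert order[:4] == [(0, 0), (0, 1), (1, 0), (1, 1)]
assert order[8:12] == [(0, 2), (0, 3), (1, 2), (1, 3)]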
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
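To make the table above concrete: in the regular case of exactly two kept values per group of four, the nibble is just the two kept indices, two bits each, with the lowest index in the low bits. A minimal standalone check (the degenerate fewer/more-than-two cases are handled by the minimized logic below and are not covered here):

def meta_nibble(quad):
    # quad: four booleans, exactly two of them True (the kept positions)
    kept = [i for i, nonzero in enumerate(quad) if nonzero]
    assert len(kept) == 2, "sketch only covers the exactly-two-nonzeros case"
    return kept[0] | (kept[1] << 2)

table = {
    (True, True, False, False): 0b0100,
    (True, False, True, False): 0b1000,
    (False, True, True, False): 0b1001,
    (True, False, False, True): 0b1100,
    (False, True, False, True): 0b1101,
    (False, False, True, True): 0b1110,
}
assert all(meta_nibble(q) == nibble for q, nibble in table.items())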
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch25-cxx11-cu124-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
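The branch just below exists because (q - zp) * s and q * s - zp * s are algebraically identical but round differently in low precision; computing the reference with the same ordering as the kernel keeps that rounding noise out of the comparison. A toy illustration with made-up fp16 numbers:

import torch

q = torch.tensor([3.0, 7.0, 12.0], dtype=torch.float16)    # quantized values
zp = torch.tensor([8.0], dtype=torch.float16)               # zero-point
s = torch.tensor([0.0173], dtype=torch.float16)             # scale

ref_a = (q - zp) * s        # zero-point applied before the scale
ref_b = q * s - zp * s      # zero-point applied after the scale (Machete-style)
assert torch.allclose(ref_a, ref_b, atol=1e-3)               # equal up to fp16 rounding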
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
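Before the QQQ-specific variant below, here is a compressed, standalone restatement (torch only, toy sizes) of the symmetric, no-zero-point path taken by quantize_weights/gptq_quantize_weights above for a uint4b8-style type: derive a per-group scale from the extremes, round and clamp, add the bias for storage, then undo bias and scale for the reference.

import torch

w = torch.randn(128, 4) * 0.1          # toy fp32 weights, one group per column
max_q, min_q, bias = 7, -8, 8          # uint4b8: logical [-8, 7], stored [0, 15]

# per-column scale, mirroring the no-zero-point branch of quantize_weights
s = torch.max((w.max(dim=0, keepdim=True).values / max_q).abs(),
              (w.min(dim=0, keepdim=True).values / min_q).abs())
q = torch.clamp(torch.round(w / s), min_q, max_q) + bias    # stored values
w_ref = (q - bias) * s                                      # dequantized reference

assert int(q.min()) >= 0 and int(q.max()) <= 15
assert torch.all((w - w_ref).abs() <= s / 2 + 1e-6)         # at most half a step off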
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 87cf60b43ce91ca35305cae2f46d1c16e23af824..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:893739f27c86e11a259df04fd24021f374256e434339f682baaf0c5fccfc3c8a -size 63468944 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. - """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/marlin.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - 
b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. 
- """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. 
- (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
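One detail worth noting before the next hunk: awq_to_marlin_zero_points undoes the AWQ column interleave by indexing with numpy.argsort of the same permutation, which yields its inverse. A tiny self-contained check (illustrative only):

import numpy

interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])  # 4-bit dequant ordering
undo_interleave = numpy.argsort(interleave)          # inverse permutation

cols = numpy.arange(8)
# Applying the interleave and then its argsort round-trips exactly.
assert (cols[interleave][undo_interleave] == cols).all()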
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu118-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
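As a concrete illustration of the fused-layer bookkeeping described in the note above, here is a tiny hedged sketch; the prefix and ignore list are made-up example values, and the local mapping simply mirrors the FUSED_LAYER_NAME_MAPPING dict defined just below (see also is_layer_skipped further down).

# Hypothetical example values; the mapping mirrors FUSED_LAYER_NAME_MAPPING below.
fused_to_shards = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}

prefix = "model.layers.0.self_attn.qkv_proj"   # made-up module prefix
ignored_layers = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]

proj_name = prefix.split(".")[-1]
shard_prefixes = [prefix.replace(proj_name, s) for s in fused_to_shards[proj_name]]
# is_layer_skipped (defined below) requires all shards of a fused layer to agree.
print(all(p in ignored_layers for p in shard_prefixes))   # True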
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
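Before the QQQ-specific variant below, a short usage sketch of gptq_quantize_weights defined above may help. It is illustrative only: it assumes the built `quantization` package from this tree is importable, and the shapes and group size are arbitrary choices.

import torch

from quantization.scalar_type import scalar_types
from quantization.utils.quant_utils import gptq_quantize_weights

size_k, size_n, group_size = 256, 128, 128
w = torch.randn(size_k, size_n, dtype=torch.float32)

# GPTQ-style 4-bit quantization (uint4 with a bias of 8), no act_order.
w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
    w, scalar_types.uint4b8, group_size, act_order=False
)

print(int(q_w.min()), int(q_w.max()))   # raw codes lie within [0, 15]
print(tuple(s.shape))                   # (size_k // group_size, size_n) == (2, 128)
print(float((w - w_ref).abs().max()))   # rounding error, at most half the per-group scale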
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 18492fae94fa90848d84630fe50a986861f36d84..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6692deb3c40c4bcee0ff28bf9d426c843f1a858e2a0bd12d92b5332c0adff4cf -size 64992856 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. - """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/marlin.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - 
b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. 
- """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. 
- (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
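marlin_zero_points above leans on pack_cols from quant_utils to squeeze eight 4-bit zero-points into each int32 along the column dimension. The following is a self-contained sketch of that column packing and its inverse (arbitrary small shapes; only numpy and torch are assumed), not the library entry point itself:

import numpy as np
import torch

num_bits = 4
pack_factor = 32 // num_bits          # eight 4-bit values per int32
size_k, size_n = 2, 16                # size_n must be divisible by pack_factor

zp = torch.randint(0, 1 << num_bits, (size_k, size_n), dtype=torch.int32)

# Pack along the column dimension, the same bit layout pack_cols produces.
q = zp.numpy().astype(np.uint32)
packed = np.zeros((size_k, size_n // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    packed |= q[:, i::pack_factor] << (num_bits * i)

# Unpack again (the inverse performed by unpack_cols) and check the round trip.
unpacked = np.zeros_like(q)
mask = (1 << num_bits) - 1
tmp = packed.copy()
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = tmp & mask
    tmp >>= num_bits

assert np.array_equal(unpacked, q)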
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
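As a sanity check on the addressing used just below, the ColumnMajorInterleaved<2> offset formula maps every (row, column) pair of the meta matrix to a distinct slot. A toy, self-contained illustration with made-up sizes; the row swizzle computed above is deliberately left out of this sketch:

# Toy check of the interleaved column-major addressing step only.
import torch

m, meta_ncols, interleave = 4, 4, 2
rows = torch.arange(m)[:, None].repeat(1, meta_ncols)
cols = torch.arange(meta_ncols).repeat(m, 1)
offsets = (cols // interleave) * m * interleave + rows * interleave + cols % interleave
assert sorted(offsets.view(-1).tolist()) == list(range(m * meta_ncols))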
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
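The table above can be verified with a few lines of plain Python. The helper below mirrors the boolean expressions used in the code that follows, with 0/1 ints standing in for the torch masks and `1 - x` playing the role of `~x`:

# Reproduce the quadruple -> 4-bit metadata encoding from the table above.
def encode_quad(m0, m1, m2, m3):
    # m2 does not enter the encoding, exactly as in the expressions below.
    expr0 = m0 & m1
    expr1 = (1 - m0) & m1
    expr2 = (1 - m0) & (1 - m1)
    idx0 = expr1 | (expr2 << 1)                              # index of first kept value
    idx1 = (expr0 | expr2 | m3) | ((expr1 | (1 - m1)) << 1)  # index of second kept value
    return idx0 | (idx1 << 2)

table = {
    (1, 1, 0, 0): 0b0100, (1, 0, 1, 0): 0b1000, (0, 1, 1, 0): 0b1001,
    (1, 0, 0, 1): 0b1100, (0, 1, 0, 1): 0b1101, (0, 0, 1, 1): 0b1110,
}
assert all(encode_quad(*quad) == code for quad, code in table.items())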
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu121-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
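Numerically, the two reference formulations are the same expression written in a different order, (q - zp) * s versus q * s - zp * s; the branch below simply matches whichever order the kernel under test applies. A tiny illustration with made-up values:

# Illustration only: both dequantization orderings agree in fp32.
import torch

q = torch.tensor([0.0, 3.0, 7.0, 15.0])   # hypothetical uint4 codes
zp = torch.tensor(8.0)                     # hypothetical zero-point
s = torch.tensor(0.05)                     # hypothetical scale

assert torch.allclose((q - zp) * s, q * s - zp * s)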
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
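Before the full routine, a condensed numeric sketch of the per-channel case it implements (symmetric signed quantization, no zero-point), with toy sizes and num_bits=4 chosen only for illustration:

# Condensed per-channel sketch; the function below additionally handles the
# per-group path (with its int8 re-quantization and scale fusion).
import torch

w = torch.randn(16, 8)                    # toy (size_k, size_n) weight
num_bits = 4
max_q = 2 ** (num_bits - 1) - 1           # 7 for symmetric int4

s_channel = w.abs().max(dim=0, keepdim=True).values / max_q
q_w = torch.clamp(torch.round(w / s_channel), -max_q, max_q).int()
w_ref = q_w * s_channel                   # dequantized reference

# Round-to-nearest error is bounded by half a scale step per column.
assert (w - w_ref).abs().max() <= s_channel.max() / 2 + 1e-6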
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. 
- """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 0fb8d8c63bf07a0e4c92896e001b53ca536aad85..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f232faff4a6793b272825d95405d811f3ccbf8c2393e127d3f6772ff2441f165 -size 67519424 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. 
- """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. - input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/cutlass.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. - """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/marlin.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - 
b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. -@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. 
- """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. - return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. 
- (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/__init__.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
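A minimal, self-contained sketch (illustration only, not taken from these sources) of the bit
layout produced by the pack_cols call at the end of marlin_zero_points above, assuming
num_bits = 4 so that 32 // num_bits = 8 adjacent columns collapse into one int32 column:

import numpy as np

num_bits = 4
pack_factor = 32 // num_bits                       # 8 nibbles per int32
zp = np.arange(8, dtype=np.uint32).reshape(1, 8)   # one group, eight 4-bit zero-points

packed = np.zeros((1, 8 // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
    # column i occupies bits [num_bits * i, num_bits * (i + 1)) of the packed word
    packed |= zp[:, i::pack_factor] << (num_bits * i)

assert packed[0, 0] == 0x76543210  # nibble i holds the value from column i

awq_to_marlin_zero_points below first undoes the AWQ packing with unpack_cols before routing the
values through marlin_zero_points again.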
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
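# (Illustrative gloss of the offset computation below: with interleave = 2, the element at
# reordered row r and column c is scattered to flat offset (c // 2) * (2 * m) + 2 * r + (c % 2),
# so column pairs are stored contiguously and, within a pair, elements are ordered by row.)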
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
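# (Worked instance of the table above: for the quadruple [True, False, False, True] the kept
# values sit at indices 0 and 3, so idxs0 below becomes 0 (low two bits) and idxs1 becomes 3
# (high two bits), giving the listed encoding 0b1100.)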
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py b/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch25-cxx98-cu124-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
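# (Illustrative note: the mapping below is what lets is_layer_skipped() treat a fused name such
# as "qkv_proj" as skipped only when every unfused shard it covers (q_proj, k_proj, v_proj)
# appears in ignored_layers; shards with mixed status raise a ValueError.)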
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
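# (Illustrative gloss: algebraically (w_q - zp) * s equals w_q * s - zp * s, but the two
# orderings round differently in low precision, so the reference below mirrors whichever
# ordering the kernel under test applies.)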
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
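# (Illustrative summary of the function below: for per-group quantization (group_size < size_k)
# the weights get unsigned num_bits values with a per-group scale that is later fused with a
# per-channel int8 scale; for per-channel quantization (group_size == size_k) they are quantized
# symmetrically with a single per-channel scale.)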
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 2f6a935abdc27f9aed50958347729676e1ea5b3a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:932f97f2f20cfe21d3f9b75494026e85fc7552c0aac43113ad1af6715a32482c -size 63484368 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0826a614b8629beb019928bf7d4fd71fc4b0d205 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:57e3f40d3bd58464cc005538c0d5376d64d7b6051f819f34c024f5b1940afb0f +size 155760312 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/cutlass.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/marlin.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/marlin.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/platforms.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
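One note on this new platforms.py: the rewritten compressed_tensors.py above calls current_platform.fp8_dtype(), but no such method appears in the hunks shown here. If it is genuinely absent rather than defined outside the visible context, a concrete method on Platform along the following lines would match the call site; the method name is taken from the caller and the dtype choice follows the ROCm/MI300 comment in compressed_tensors.py, so treat this as an assumption, not the shipped implementation:

import torch

# Hypothetical sketch, written as it could appear on the Platform base class:
def fp8_dtype(self) -> torch.dtype:
    # ROCm (MI300-class) exposes the fnuz FP8 variant; CUDA uses e4m3fn.
    return torch.float8_e4m3fnuz if self.is_rocm() else torch.float8_e4m3fn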
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/scalar_type.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
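A worked instance of the bit manipulation described in the comment above (editorial sketch; the value 1.5 and the CPU check are illustrative and assume a little-endian host, as on these x86_64 builds):

import torch

s = torch.tensor([[1.5]], dtype=torch.half)  # a non-negative scale in fp16
# 1.5 * 2**7 = 192.0, whose fp16 bits are 0x5A00; shifting the int16 view left
# by one drops the sign bit, leaving 0xB400, and the upper byte 0xB4 holds the
# five exponent bits 10110 followed by the top three mantissa bits 100 -- the
# S0E5M3 layout that the [:, 1::2] slice in the code that follows extracts.
hi_byte = ((s * 2**7).view(torch.int16) << 1).view(torch.uint8)[:, 1::2]
assert hi_byte.item() == 0xB4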
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch26-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/__init__.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/_ops.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 72f45ae0f82e5d3663fac23cb28a8950828791a9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:36d47af9178b6187e6a14651f30b21f6038d31ed591688aba5c03b29b0bf88cc -size 67517488 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bffbffec719ac5d13f4bfb4905185dd7ef12f07d --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b9f064f096a8814f4e1441a58dd10979e111c6dce94fbf11a381a0463150577a +size 159574104 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/cutlass.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/marlin.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/marlin.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/platforms.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
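+ # The concrete CudaPlatform/RocmPlatform classes below implement this + # interface; the module-level `current_platform` at the end of this file is + # resolved once at import time (RocmPlatform when torch.version.hip is not + # None, CudaPlatform otherwise).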
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/scalar_type.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
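+ # Bit-level reading of the conversion below (assumes a little-endian byte + # layout): reinterpreting the fp16 scales as int16 and shifting left by one + # drops the sign bit (zero by assumption), and keeping byte 1 of every + # 16-bit element selects the high byte, which holds the 5 exponent bits and + # the top 3 mantissa bits, giving one FP8-S0E5M3 scale per input scale.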
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch26-cxx11-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/__init__.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/_ops.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
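+ For example, add_op_namespace_prefix("gptq_marlin_gemm") returns + "_quantization_e8730d8_dirty::gptq_marlin_gemm".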
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 245c0b03cb43d563a640f30ab8b5487cdaba0871..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6b7a98b58caa01f436b3f089dfb62e2ec96a85ffdfad621f332701e0bc69b6a8 -size 68279984 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c817f86dfd48a9a768953758d2b5e07323eb6182 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1eb61783122b0f53ed36db827c8ea9cf1da094f22ca0433c575d18993565531f +size 160276600 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/cutlass.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/marlin.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/marlin.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/platforms.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
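+ # The concrete CudaPlatform/RocmPlatform classes below implement this + # interface; the module-level `current_platform` at the end of this file is + # resolved once at import time (RocmPlatform when torch.version.hip is not + # None, CudaPlatform otherwise).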
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/scalar_type.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
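# Illustrative note on the bit trick below: at this point marlin_scales is
# fp16, laid out as [1 sign bit | 5 exponent bits | 10 mantissa bits].
# Multiplying by 2**7 and shifting the raw bits left by one discards the
# (zero) sign bit, so the upper byte of each 16-bit element becomes
# [5 exponent bits | top 3 mantissa bits], which is the FP8-S0E5M3 layout
# described above. The [:, 1::2] selection then keeps exactly that upper
# byte of every (little-endian) 16-bit element.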
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch26-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/__init__.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/_ops.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 36f29c974687b550b345056669afa9869ebb95a4..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:7eacdfbcfd25927283bc5c3653704a6c27da69a7c8ba7ef68f0691a66679054c -size 2878744 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. 
- """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/marlin.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: 
torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. 
-@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. - """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. 
- return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. 
has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/__init__.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
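# Illustrative sketch (added here, not part of the original files): the "undo
# interleave" step in awq_to_marlin_zero_points below relies on numpy.argsort
# returning the inverse of a permutation, so applying the interleave pattern and
# then argsort(interleave) restores the original column order.
import numpy

interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])   # 4-bit column interleave
undo_interleave = numpy.argsort(interleave)           # inverse permutation

cols = numpy.arange(8)
assert (cols[interleave][undo_interleave] == cols).all()  # round-trips to identity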
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
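# Worked example (added illustration, not part of the original file): for the
# quadruple [True, False, True, False] the kept positions are 0 and 2, so the
# low two bits store index 0 and the high two bits store index 2, giving the
# 0b1000 entry in the table above. With plain 0/1 scalars (using (1 - x) in
# place of the boolean ~x used below):
#   m0, m1, m3 = 1, 0, 0
#   expr0 = m0 & m1                      # 0
#   expr1 = (1 - m0) & m1                # 0
#   expr2 = (1 - m0) & (1 - m1)          # 0
#   idxs0 = expr1 | (expr2 << 1)                              # 0
#   idxs1 = (expr0 | expr2 | m3) | ((expr1 | (1 - m1)) << 1)  # 2
#   meta  = idxs0 | (idxs1 << 2)                              # 0b1000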
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
Note that torch.float - # datatype is handled pretty much the same as - # torch.half/torch.bfloat16, as metadata for a pair of torch.float - # value is encoded as if underlying 8 bytes contain four - # torch.half/torch.bfloat16 values, where either first two or last - # two are zeros. - meta_2 = torch.empty( - (m, meta_ncols, 2 * quadbits_per_meta_elem), - dtype=meta_dtype, - device=device, - ) - if quadbits_per_meta_elem == 4: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - elif quadbits_per_meta_elem == 8: - meta_2[:, :, 0] = meta & 0b11 - meta_2[:, :, 1] = (meta >> 2) & 0b11 - meta_2[:, :, 2] = (meta >> 4) & 0b11 - meta_2[:, :, 3] = (meta >> 6) & 0b11 - meta_2[:, :, 4] = (meta >> 8) & 0b11 - meta_2[:, :, 5] = (meta >> 10) & 0b11 - meta_2[:, :, 6] = (meta >> 12) & 0b11 - meta_2[:, :, 7] = (meta >> 14) & 0b11 - meta_2[:, :, 8] = (meta >> 16) & 0b11 - meta_2[:, :, 9] = (meta >> 18) & 0b11 - meta_2[:, :, 10] = (meta >> 20) & 0b11 - meta_2[:, :, 11] = (meta >> 22) & 0b11 - meta_2[:, :, 12] = (meta >> 24) & 0b11 - meta_2[:, :, 13] = (meta >> 26) & 0b11 - meta_2[:, :, 14] = (meta >> 28) & 0b11 - meta_2[:, :, 15] = (meta >> 30) & 0b11 - - dense_offsets = meta_2.view(-1) + ( - torch.arange(0, 2 * m * k // ksparse, device=device) * 4 - ).view(-1, 1).repeat(1, 2).view(-1) - - dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device) - if sparse.dtype != torch.float: - # dense.scatter_(0, dense_offsets, sparse.view(-1)) - dense.scatter_(0, dense_offsets, sparse.reshape(-1)) - else: - dense.view(torch.half).scatter_( - 0, dense_offsets, sparse.view(torch.half).view(-1) - ) - - return dense.view(m, 2 * k) - - -def mask_creator(tensor): - """ - Class for creating N:M sparsity masks. - Masks will be created using the N:M ratio, where for every block of - M weights, N will be pruned based on ranked weight value. Each mask - will correspond to the given tensor. 
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/quant_utils.py b/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch26-cxx11-rocm62-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
-# fused_name: List[shard_name] -FUSED_LAYER_NAME_MAPPING = { - "qkv_proj": ["q_proj", "k_proj", "v_proj"], - "gate_up_proj": ["gate_proj", "up_proj"], -} - - -def pack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - assert w_q_perm.shape[-1] % pack_factor == 0 - new_shape_perm[-1] //= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i - - return res.permute(inv_perm) - - -def unpack_quantized_values_into_int32( - w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0 -): - # move dim to pack to the end - perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) - inv_perm = tuple(perm.index(i) for i in range(len(perm))) - w_q_perm = w_q.permute(perm) - - pack_factor = 32 // wtype.size_bits - mask = (1 << wtype.size_bits) - 1 - - new_shape_perm = list(w_q_perm.shape) - new_shape_perm[-1] *= pack_factor - - res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) - for i in range(pack_factor): - res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask - - return res.permute(inv_perm) - - -def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool: - # prefix: model.layers.0.self_attn.q_proj - # proj_name: q_proj - proj_name = prefix.split(".")[-1] - if proj_name in FUSED_LAYER_NAME_MAPPING: - shard_prefixes = [ - prefix.replace(proj_name, shard_proj_name) - for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name] - ] - - is_skipped = None - for shard_prefix in shard_prefixes: - is_shard_skipped = shard_prefix in ignored_layers - - if is_skipped is None: - is_skipped = is_shard_skipped - elif is_shard_skipped != is_skipped: - raise ValueError( - f"Detected some but not all shards of {prefix} " - "are quantized. All shards of fused layers " - "to have the same precision." 
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
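# Added illustration (not in the original file): both reference formulas compute
# the same value in exact arithmetic,
#   (w_q - w_zp) * w_s  ==  w_q * w_s - w_zp * w_s,
# but the two orderings round differently in half/bfloat16, so mirroring the
# kernel's order (scale first, then subtract the scaled zero-point) is what lets
# the tests use tighter error tolerances.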
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
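# Usage sketch (added illustration; assumes the qqq_quantize_weights defined just
# below, with shapes shown for a hypothetical 1024x256 weight):
#   w = torch.randn(1024, 256, dtype=torch.half, device="cuda")
#   w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4, group_size=128)
#   # per-group path: q_w holds int codes in [0, 15], s_group is an fp16
#   # (1024 // 128, 256) per-group scale, s_channel a (1, 256) fp32 per-channel
#   # scale; with group_size == -1 (per-channel), s_group comes back empty.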
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 084e66466df36b9ee86b10d8e26fca41c0bcb642..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:3885094b146b702d2ac23780b0f102500c30f46e53cfaf42bf527d708485979a -size 63479104 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..1d18a8090b605c542bf0ea9d9b32b26ec0ebc1a3 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e6c2a3029d72467b8bf2fbfcf8e999683e58ab0a1c0eb4bc5fda1d92cfcc179d +size 155740048 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/cutlass.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/marlin.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/marlin.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/platforms.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
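The DeviceCapability helper just added in platforms.py is what the package uses to collapse a (major, minor) compute capability into the integer form that the capability gates elsewhere in this diff compare against (for example capability >= 80 for the Marlin FP8/FP4 paths). A small sketch, assuming the built package exposes it at quantization.platforms:

from quantization.platforms import DeviceCapability

cap = DeviceCapability(major=8, minor=9)   # e.g. an Ada-class GPU (L4, RTX 4090)
print(cap.as_version_str())                # "8.9"
print(cap.to_int())                        # 89
assert cap.to_int() >= 80                  # the >= 80 gate used by is_fp8_marlin_supported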
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/scalar_type.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
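The use_atomic_add plumbing added above all funnels through should_use_atomic_add_reduce; restated as a pure function (environment lookup and logging stripped out, the env flag passed in explicitly) the decision is easier to read. This is a sketch of the same heuristic, not a replacement for the helper.

import torch

def atomic_add_is_worthwhile(m: int, n: int, k: int, device: torch.device,
                             dtype: torch.dtype, enabled_via_env: bool) -> bool:
    # atomicAdd only beats the global-memory reduce when the output is narrow
    # (n < 2048) and the reduction dimension is long (k >= 2048), on CUDA devices.
    # m is kept for signature parity; the check itself only looks at n and k.
    if n >= 2048 or k < 2048 or device.type != "cuda":
        return False
    # Opt-in only: the real helper reads VLLM_MARLIN_USE_ATOMIC_ADD.
    if not enabled_via_env:
        return False
    # Before SM90 there is no native bfloat16 atomicAdd, so bf16 inputs fall back.
    major, _ = torch.cuda.get_device_capability(device)
    if major < 9 and dtype == torch.bfloat16:
        return False
    return True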
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch26-cxx98-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/__init__.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/_ops.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
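The pack_fp8_to_int32 rewrite above replaces the explicit byte shifting with a dtype view, which works because four consecutive float8_e4m3fn bytes along a contiguous last dimension are exactly one little-endian int32 (these builds target little-endian x86_64). A quick CPU equivalence check, with fp8 used purely as a byte container:

import torch

w = torch.randn(8, 16).to(torch.float8_e4m3fn)   # (N, K), K % 4 == 0, contiguous

packed_view = w.view(torch.int32)                # (N, K // 4): what the new helper returns

# Manual little-endian packing of the same bytes, done in int64 to sidestep
# signed-overflow questions, then compared against the raw bits of the view.
b = w.view(torch.uint8).to(torch.int64)
packed_manual = (b[:, 0::4] | (b[:, 1::4] << 8) | (b[:, 2::4] << 16) | (b[:, 3::4] << 24))

assert torch.equal(packed_view.to(torch.int64) & 0xFFFFFFFF, packed_manual)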
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 6dbc78a0c7392c047c3d57e94dbcf5207cff2f37..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:50d412800766f33cd706a8126bc47e8119001bcd1cadd96a8e408be114b3b1b7 -size 67509408 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e6c48077887cdf53edcf6ba6768607bece94ef5c --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2932ba43dd1ae4848b3077dada99be0088023a56e7b36bac9e863a1977249088 +size 159578496 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/cutlass.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/marlin.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/marlin.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/platforms.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
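A small illustration of the DeviceCapability helper defined above; to_int() packs the capability the same way the Marlin support checks compare it (e.g. against a threshold of 80 for SM80):

cap = DeviceCapability(major=9, minor=0)
cap.as_version_str()  # "9.0"
cap.to_int()          # 90, comparable against thresholds such as >= 80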
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/scalar_type.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
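A standalone check of the bit layout described in the comment above, done in plain Python for a single scalar. The e5m3_decode helper and its bias of 15 are assumptions for illustration only; they mirror how the top byte of an IEEE fp16 bit pattern reads once the sign bit has been shifted out:

import struct

def e5m3_decode(byte: int, bias: int = 15) -> float:
    # Hypothetical helper: read an unsigned E5M3 byte (no sign bit).
    exp, mant = byte >> 3, byte & 0b111
    if exp == 0:
        return (mant / 8.0) * 2.0 ** (1 - bias)
    return (1.0 + mant / 8.0) * 2.0 ** (exp - bias)

s = 0.375                                                  # a non-negative FP8-S1E4M3 scale
h = struct.unpack('<H', struct.pack('<e', s * 2 ** 7))[0]  # fp16 bit pattern of s * 2**7
hi_byte = ((h << 1) & 0xFFFF) >> 8                         # drop the sign bit, keep the top byte
print(e5m3_decode(hi_byte), s * 2 ** 7)                    # both print 48.0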
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch26-cxx98-cu124-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/__init__.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/_ops.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 74c9cc5dab399184c8d4d3f4953656d77cf9dcba..0000000000000000000000000000000000000000 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:acf7c34de931bd4ab88454fad93df1afeccc84b8158a410fb9153ba42e5e82bc -size 68271904 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2fe72fdc0bcc9207a15e4539c151779e61e3ca8d --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1dcf28c2d636d90cd8af8bc7b44a3b7d5f5a1a599e7e1c03b06f3800d40f5a60 +size 160274448 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/compressed_tensors.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/cutlass.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/cutlass.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/marlin.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/marlin.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/platforms.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
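The rewritten scaled_fp8_quant resolves its output dtype through current_platform.fp8_dtype(), which this platforms.py hunk does not show. A minimal sketch of how such a dispatch could look, under the assumption that ROCm (MI300) uses the fnuz variant while CUDA uses the standard OCP e4m3 type:

import torch

IS_ROCM = torch.version.hip is not None

def fp8_dtype() -> torch.dtype:
    # Assumption for illustration: e4m3fnuz on ROCm MI300, e4m3fn elsewhere.
    return torch.float8_e4m3fnuz if IS_ROCM else torch.float8_e4m3fn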
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/scalar_type.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/scalar_type.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
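A minimal round-trip sketch of the FP8-S0E5M3 packing implemented just below — illustrative only, assuming non-negative half-precision scales, a little-endian byte layout, and toy values that fit in the 3 retained mantissa bits:

import torch

s = torch.tensor([0.25, 1.0, 3.5], dtype=torch.half)  # sample scales (assumed)
# Multiply by 2**7, reinterpret the fp16 bits, drop the (zero) sign bit with a
# left shift, and keep the high byte of each int16: 5 exponent + 3 mantissa bits.
packed = ((s * 2**7).view(torch.int16) << 1).view(torch.uint8)[1::2]
# Decode by restoring the discarded low bits as zeros and undoing the 2**7 bias.
decoded = (packed.to(torch.int16) << 7).view(torch.half) / 2**7
assert torch.equal(decoded, s)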
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch26-cxx98-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/__init__.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 7795352d24ea348397c20d8a8670567d0bd94bf2..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:0a1650df5eeb0e1494932eec92425751998ea2181b86995671edb757ccf6aeb5 -size 63484856 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..678fe1639681e60dfb875631816606515b3e047c --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dfeeef0e0e812038c52f838b994c631faa236af0d360246951dfc3e07ab0a461 +size 155756888 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/cutlass.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/cutlass.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/marlin.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/marlin.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/platforms.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
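A short usage sketch for the helpers defined in this platforms.py module — a hypothetical call site, assuming the package is importable as `quantization` and a CUDA or ROCm device is visible:

import torch

from quantization.platforms import DeviceCapability, current_platform

# DeviceCapability packs (major, minor) into the usual two-digit integer.
assert DeviceCapability(major=8, minor=6).to_int() == 86

if torch.cuda.is_available():
    cap = current_platform.get_device_capability().to_int()
    # Same >= 8.0 threshold the marlin/fp8 support checks in this diff rely on.
    marlin_capable = cap >= 80
    print(current_platform.get_device_name(), cap, marlin_capable)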
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/scalar_type.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/scalar_type.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
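# (Editorial note, added for clarity; not part of the original patch.) At this
# point the tensor is still torch.half, laid out as [sign:1][exponent:5][mantissa:10].
# Viewing it as int16 and shifting left by one discards the sign bit, and the
# [:, 1::2] selection after the float8_e4m3fn reinterpretation keeps the high
# byte of each (little-endian, as on these x86_64 builds) int16, i.e. the five
# exponent bits plus the top three mantissa bits, which is the "FP8-S0E5M3"
# layout described in the comment above.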
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/__init__.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index fa34aa0ab7aed49eef08bab955e354b2af3cf583..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:a526fafcee66a673aa3aa460105ae3374d30e3fade68ab4cabd995482e467258 -size 68275920 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..3c8b93289b32e964434efeacf924c3bcd9e323a5 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2d6c023d7381396997b58ff6bdaa002db2ab94a0c0eb17d09512a1a9f8e888d2 +size 160280720 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/cutlass.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/cutlass.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/marlin.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/marlin.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/platforms.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
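# Editorial usage sketch for the platform helpers defined in this new module.
# Illustrative only: it assumes an NVIDIA GPU, so `current_platform` resolves
# to the `CudaPlatform` defined just below, and that the import path is
# `quantization.platforms` as used elsewhere in this diff.
from quantization.platforms import current_platform

cap = current_platform.get_device_capability()  # e.g. DeviceCapability(8, 0)
print(cap.as_version_str())                     # "8.0"
print(cap.to_int())                             # 80
print(current_platform.is_rocm())               # False on CUDA builds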
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/scalar_type.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/scalar_type.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
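# (Editorial note, not part of the original patch.) Because the scales are
# assumed to be non-negative (see the warning at the top of this function),
# the fp16 sign bit that the left shift below discards carries no
# information, so only exponent and mantissa bits survive into the packed
# FP8-S0E5M3 representation.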
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/__init__.py index c3ab3b032c29f7bbafd549915dbc677c45a33837..5ffacf871b3c23e224da813f506996f2a755d3e1 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/__init__.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/__init__.py @@ -1,12 +1,12 @@ from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant from .cutlass import ( + cutlass_scaled_mm_supports_block_fp8, cutlass_scaled_mm_supports_fp8, cutlass_scaled_mm, cutlass_scaled_mm_azp, ) from .marlin import ( awq_marlin_repack, - fp8_marlin_gemm, gptq_marlin_gemm, gptq_marlin_repack, gptq_marlin_24_gemm, @@ -25,8 +25,8 @@ __all__ = [ "awq_marlin_repack", "cutlass_scaled_mm", "cutlass_scaled_mm_azp", + "cutlass_scaled_mm_supports_block_fp8", "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", "gptq_marlin_24_gemm", "gptq_marlin_gemm", "gptq_marlin_repack", diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/_ops.py index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..723548885679ef828d1fee0c6b4907c25ef867e1 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb +from . import _quantization_e8730d8_dirty +ops = torch.ops._quantization_e8730d8_dirty def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. 
""" - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file + return f"_quantization_e8730d8_dirty::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index 9fd6a6b64cac1e87e3f208198312e8b492ba8fd6..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c3f633e0aa3070b9adcea2e365f0c19ad9c814a2fd7d5f9c82540b6e755cfa09 -size 121794816 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..9cb2d5cfba82f74d9a13e944bc33b1ab31da38d7 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/_quantization_e8730d8_dirty.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:718b7895b3e802aee133dcdbdbfd4aafa1dfed30a7a2b08547d97ec738b29c6e +size 297107160 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/compressed_tensors.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/compressed_tensors.py index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..94f6ed8dc68b104b620fa88ea7c41700c6b277dd 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/compressed_tensors.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/compressed_tensors.py @@ -2,17 +2,7 @@ from typing import Optional, Tuple import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - +from ._ops import ops # fp8 def scaled_fp8_quant( @@ -21,7 +11,8 @@ def scaled_fp8_quant( num_token_padding: Optional[int] = None, scale_ub: Optional[torch.Tensor] = None, use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: + output: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize input tensor to FP8 and return quantized tensor and scale. @@ -42,30 +33,36 @@ def scaled_fp8_quant( in the dynamic quantization case. Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and + tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and scaling factor. 
""" # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn + assert (input.ndim == 2) + shape: Union[tuple[int, int], torch.Size] = input.shape + # For ROCm on MI300, the output fp8 dtype is torch.float_e3m3fnuz + out_dtype: torch.dtype = current_platform.fp8_dtype() if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) + if output is None: + output = torch.empty(shape, device=input.device, dtype=out_dtype) + else: + assert num_token_padding is None, \ + "padding not supported if output passed in" + assert output.dtype == out_dtype if scale is None: if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) + scale = torch.empty((shape[0], 1), + device=input.device, + dtype=torch.float32) + ops.dynamic_per_token_scaled_fp8_quant( + output, input.contiguous(), scale, scale_ub) else: scale = torch.zeros(1, device=input.device, dtype=torch.float32) ops.dynamic_scaled_fp8_quant(output, input, scale) else: # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None + assert (scale.numel() == 1 and num_token_padding is None) ops.static_scaled_fp8_quant(output, input, scale) return output, scale @@ -76,8 +73,8 @@ def scaled_int8_quant( input: torch.Tensor, scale: Optional[torch.Tensor] = None, azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + symmetric: bool = True +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. @@ -90,21 +87,25 @@ def scaled_int8_quant( symmetric: Whether to use symmetric quantization (scale only, azp ignored). Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. """ output = torch.empty_like(input, dtype=torch.int8) if scale is not None: # static-per-tensor quantization. assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." + azp + is None), "azp must only be provided for asymmetric quantization." ops.static_scaled_int8_quant(output, input, scale, azp) return output, scale, azp # dynamic-per-token quantization. 
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + ops.dynamic_scaled_int8_quant(output, input.contiguous(), + input_scales, input_azp) return output, input_scales, input_azp + + diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/cutlass.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/cutlass.py index c378b846d0c59de183a321fcad4b403c47b3d750..3c4efc4a93effdb140ecd0ee4b7608caeb3a35d0 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/cutlass.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/cutlass.py @@ -2,22 +2,18 @@ from typing import Optional import torch -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e +from ._ops import ops +from .platforms import current_platform def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) +def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool: + return ops.cutlass_scaled_mm_supports_block_fp8(cuda_device_capability) + + def cutlass_scaled_mm( a: torch.Tensor, b: torch.Tensor, @@ -33,12 +29,10 @@ def cutlass_scaled_mm( m = a.shape[0] n = b.shape[1] - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." 
- # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) + if not cutlass_compatible_b: + from .triton_scaled_mm import triton_scaled_mm + return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) out = torch.empty((m, n), dtype=out_dtype, device=a.device) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/marlin.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/marlin.py index 44d5d28a2fb67af955c017af3cf1403feeecbd32..f4d5173599faebb6f31392960110ceec20346d75 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/marlin.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/marlin.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -30,58 +30,30 @@ except ImportError as e: from .scalar_type import ScalarType -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - # gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - +def gptq_marlin_gemm(a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type: ScalarType, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: + return ops.gptq_marlin_gemm(a, c, b_q_weight, b_scales, + global_scale, b_zeros, g_idx, perm, + workspace, b_q_type.id, size_m, + size_n, size_k, is_k_full, + use_atomic_add, use_fp32_reduce, + is_zp_float) # gptq_marlin def gptq_marlin_repack( @@ -153,14 +125,6 @@ def marlin_qqq_gemm( # Fake ops if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, b_meta: torch.Tensor, b_scales: torch.Tensor, @@ -172,20 +136,22 @@ if hasattr(ops, "gptq_marlin_24_gemm"): @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: 
torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: torch.SymInt, + size_n: torch.SymInt, + size_k: torch.SymInt, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/platforms.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/platforms.py new file mode 100644 index 0000000000000000000000000000000000000000..735fab87f2add390f7bf6408ebe31d1f5de6d02b --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/platforms.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import NamedTuple + +import torch + +IS_ROCM = torch.version.hip is not None + + +class DeviceCapability(NamedTuple): + major: int + minor: int + + def as_version_str(self) -> str: + return f"{self.major}.{self.minor}" + + def to_int(self) -> int: + """ + Express device capability as an integer ````. + + It is assumed that the minor version is always a single digit. + """ + assert 0 <= self.minor < 10 + return self.major * 10 + self.minor + + +class Platform(ABC): + simple_compile_backend: str = "inductor" + + @classmethod + @abstractmethod + def get_device_name(cls, device_id: int = 0) -> str: ... + + @abstractmethod + def is_rocm(self): ... 
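The fp8 quantization path above (compressed_tensors.py) calls current_platform.fp8_dtype(), which is not defined in this platforms.py hunk; a plausible shape for that helper is sketched below purely as an assumption, alongside a typical consumer pattern:

import torch

def fp8_dtype_for(platform) -> torch.dtype:
    # Hypothetical helper: MI300-class ROCm devices use the fnuz variant of
    # FP8, while CUDA devices use float8_e4m3fn.
    return torch.float8_e4m3fnuz if platform.is_rocm() else torch.float8_e4m3fn

# Typical consumer pattern (mirroring cutlass.py / compressed_tensors.py):
#   cap = current_platform.get_device_capability().to_int()
#   out_dtype = fp8_dtype_for(current_platform)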
+ + +class CudaPlatform(Platform): + @classmethod + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(0) + + def is_rocm(self): + return False + + +class RocmPlatform(Platform): + @classmethod + @lru_cache(maxsize=8) + def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: + major, minor = torch.cuda.get_device_capability(device_id) + return DeviceCapability(major=major, minor=minor) + + @classmethod + @lru_cache(maxsize=8) + def get_device_name(cls, device_id: int = 0) -> str: + return torch.cuda.get_device_name(device_id) + + def is_rocm(self): + return True + + +current_platform = RocmPlatform() if IS_ROCM else CudaPlatform() diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/scalar_type.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/scalar_type.py index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..9060b55c79b0185c5dd22a2940e4363f69ccabeb 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/scalar_type.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/scalar_type.py @@ -1,9 +1,14 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import functools import struct from dataclasses import dataclass from enum import Enum from typing import Optional, Union +_SCALAR_TYPES_ID_MAP = {} + # Mirrors enum in `core/scalar_type.hpp` class NanRepr(Enum): @@ -121,8 +126,8 @@ class ScalarType: min_raw = max_raw | sign_bit_double return struct.unpack('!d', struct.pack('!Q', min_raw))[0] else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" + assert (not self.is_signed() or self.size_bits + <= 64), "Cannot represent min as a int64_t" if self.is_signed(): return -(1 << (self.size_bits - 1)) @@ -156,6 +161,8 @@ class ScalarType: assert offset <= 64, \ f"ScalarType fields too big {offset} to fit into an int64" + _SCALAR_TYPES_ID_MAP[val] = self + return val @property @@ -293,6 +300,13 @@ class ScalarType: ret.id # noqa B018: make sure the id is cached return ret + @classmethod + def from_id(cls, scalar_type_id: int): + if scalar_type_id not in _SCALAR_TYPES_ID_MAP: + raise ValueError( + f"scalar_type_id {scalar_type_id} doesn't exists.") + return _SCALAR_TYPES_ID_MAP[scalar_type_id] + # naming generally follows: https://github.com/jax-ml/ml_dtypes # for floating point types (leading f) the scheme is: @@ -319,6 +333,9 @@ class scalar_types: # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) + # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE) + # "gptq" types uint2b2 = ScalarType.uint(2, 2) uint3b4 = ScalarType.uint(3, 4) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils.py index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..eb2f41d72984bdfbe03a6d71b632371025156448 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils.py @@ -1,4 +1,7 @@ -from typing import List, Optional, Tuple +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional import numpy import torch @@ -42,7 +45,9 @@ USE_FP32_REDUCE_DEFAULT = True # without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. # TODO: we may want to move this into the C++ so its closer to the actual impl def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None + has_zp: Optional[bool] = None, + include_fp_type: bool = True, + device_capability: Optional[int] = None, ): if device_capability is None: capability_tuple = torch.cuda.get_device_capability() @@ -51,137 +56,141 @@ def query_marlin_supported_quant_types( if device_capability < 80: return [] + # - has_zp is True: return quant_types that has zero points + # - has_zp is False: return quant_types that has not zero points + # - has_zp is None: both + if has_zp is None: + types0 = query_marlin_supported_quant_types(False, include_fp_type, + device_capability) + types1 = query_marlin_supported_quant_types(True, include_fp_type, + device_capability) + return types0 + types1 + if has_zp: # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] + return [scalar_types.uint4] else: # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] + res = [scalar_types.uint4b8, scalar_types.uint8b128] + if include_fp_type: + res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f] + return res def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> tuple[bool, Optional[str]]: if device_capability is None: capability_tuple = torch.cuda.get_device_capability() device_capability = capability_tuple[0] * 10 + capability_tuple[1] - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) + supported_types = query_marlin_supported_quant_types( + has_zp, True, device_capability) if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. 
" + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") return True, None -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) return cond -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) if not cond: assert err_msg is not None raise ValueError(err_msg) -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: # Validate output_size_per_partition if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") # Validate input_size_per_partition if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): raise ValueError( f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." + f" is not divisible by group_size = {group_size}. " "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) + "with --quantization gptq.") -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> tuple[bool, Optional[str]]: try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) except ValueError as e: return False, e.__str__() return True, None -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_make_workspace_new(device: torch.device, + max_blocks_per_sm: int = 1) -> torch.Tensor: + # In the new marlin kernel, we use the num of threadblocks as workspace + # size. The num of threadblocks is is sms_count * max_blocks_per_sm. + sms = torch.cuda.get_device_properties(device).multi_processor_count + return torch.zeros(sms * max_blocks_per_sm, + dtype=torch.int, + device=device, + requires_grad=False) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: return (not act_order) or (act_order and not is_row_parallel) -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: # Need to repeat scales on every rank if act_ordering or # channelwise and RowParallelLinear is_channelwise = group_size == -1 @@ -189,35 +198,34 @@ def marlin_repeat_scales_on_all_ranks( def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) return g_idx[g_idx_sort_indices], g_idx_sort_indices def get_scale_perms(): - scale_perm: List[int] = [] + scale_perm: list[int] = [] for i in range(8): scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] + scale_perm_single: list[int] = [] for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) return scale_perm, scale_perm_single -def marlin_permute_scales( - s: torch.Tensor, 
size_k: int, size_n: int, group_size: int -) -> torch.Tensor: +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: scale_perm, scale_perm_single = get_scale_perms() if group_size < size_k and group_size != -1: @@ -247,9 +255,8 @@ def marlin_moe_permute_scales( return output -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the # "single" permutation, since zero-points are applied on every MMA scale_perm, _ = get_scale_perms() @@ -270,9 +277,8 @@ def marlin_zero_points( return zp -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: # AWQ zero-points are quantized and packed on the column dim. # In addition, the values are permuted based on dequantizer. # Here we undo both of these, and then apply marlin permutation @@ -294,9 +300,8 @@ def awq_to_marlin_zero_points( return marlin_zp -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): num_experts = q_zp_packed.shape[0] output = torch.empty( (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), @@ -304,45 +309,97 @@ def moe_awq_to_marlin_zero_points( dtype=q_zp_packed.dtype, ) for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) return output +def maybe_warn_marlin_atomic_add(device, dtype): + if torch.compiler.is_dynamo_compiling(): + return + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + logger.info_once( + "You are running Marlin kernel with bf16 on GPUs before SM90. " + "You can consider change to fp16 to achieve better performance " + "if possible.") + + +def maybe_warn_marlin_atomic_add_env(): + if torch.compiler.is_dynamo_compiling(): + return + if envs.VLLM_MARLIN_USE_ATOMIC_ADD: + return + logger.info_once( + "Marlin kernel can achieve better performance for small size_n " + "with experimental use_atomic_add feature. 
" + "You can consider set environment variable " + "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible.") + + +def should_use_atomic_add_reduce(m: int, n: int, k: int, device: torch.device, + dtype: torch.dtype) -> bool: + + # the performance of atomicAdd is better than global reduce + # only when m*n is small and k is large + if n >= 2048 or k < 2048 or device.type != "cuda": + return False + + # disable atomicAdd reduce by default, + # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1 + if not envs.VLLM_MARLIN_USE_ATOMIC_ADD: + maybe_warn_marlin_atomic_add_env() + return False + + # sm8x doesn't support atomicAdd + bfloat16 natively + device_capability = torch.cuda.get_device_capability(device) + if device_capability[0] < 9 and dtype == torch.bfloat16: + maybe_warn_marlin_atomic_add(device, dtype) + return False + + return True + + def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add @@ -351,39 +408,43 @@ def apply_gptq_marlin_linear( def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: 
Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=output_size_per_partition, + k=reshaped_x.size(1), + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(reshaped_x, + None, + weight, + weight_scale, + None, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) if bias is not None: output.add_(bias) # In-place add diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp4.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp4.py new file mode 100644 index 0000000000000000000000000000000000000000..b6697e1394328f52681dd2b8870fe826d9be5ba3 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp4.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +import torch + +import quantization as ops + +from .marlin_utils import ( + USE_FP32_REDUCE_DEFAULT, marlin_make_workspace_new, marlin_permute_scales, + should_use_atomic_add_reduce) +from quantization.scalar_type import scalar_types + +FP4_MARLIN_SUPPORTED_GROUP_SIZES = [16] + + +def is_fp4_marlin_supported(): + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + return capability >= 80 + + +def fp4_marlin_process_scales(marlin_scales): + if not (marlin_scales >= 0).all(): + logger.warning_once( + "NVFP4 Marlin assumes the scales to be >=0, but has encountered " + "negative scales. Accuracy will likely be degraded. This is " + "because it changes the scales from FP8-S1E4M3 to a special " + "FP8-S0E5M3 format to speedup the dequantization.") + + # convert to half first, we would convert to fp8 later + marlin_scales = marlin_scales.to(torch.half) + + # 8 is the number of scale number using by one thread + marlin_scales = marlin_scales.view(marlin_scales.size(0) // 2, 2, -1, 8) + marlin_scales = marlin_scales.permute(0, 2, 1, 3).reshape( + marlin_scales.size(0) * 2, -1) + + # fit the layout of fp8 dequantization + marlin_scales = marlin_scales.view(-1, 4)[:, [0, 2, 1, 3]].view( + marlin_scales.size(0), -1) + + # We assume that weight_scale (FP8-S1E4M3) is always greater + # than or equal to 0. So we can convert + # (weight_scale * (2 ** 7) to a special FP8-S0E5M3 format. + # After multiplying by 2 ** 7, the top bit of FP8-S0E5M3 would always be 1 + # when weight_scale > 0. This allows us to have an exponent bias + # closer to zero after dequantization. 
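For reference, the FP4 (e2m1) weights these scales apply to cover only sixteen codes; the snippet below decodes them by widening each nibble into a float8_e4m3fn byte and rescaling by 2**6, the same trick rand_marlin_weight_fp4_like uses further down (sketch only, requires a torch build with float8_e4m3fn):

import torch

nibbles = torch.arange(16, dtype=torch.uint8) << 4  # place each code in the high nibble
# Keep the sign bit, move the two exponent bits and one mantissa bit into the
# e4m3 field positions, then undo the exponent-bias difference with * 2**6.
widened = (nibbles & 0b10000000) | ((nibbles & 0b01110000) >> 2)
values = widened.view(torch.float8_e4m3fn).to(torch.float32) * 2**6
# values: 0, 0.5, 1, 1.5, 2, 3, 4, 6 and their negatives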
+ + marlin_scales = (marlin_scales * (2**7)).view(torch.int16) << 1 + marlin_scales = marlin_scales.view(torch.float8_e4m3fn) + marlin_scales = marlin_scales[:, 1::2].contiguous() + + return marlin_scales + + +def fp4_marlin_process_global_scale(global_scale): + assert global_scale.dtype in [torch.half, torch.bfloat16] + fp4_exponent = 2 + if global_scale.dtype == torch.half: + target_exponent = 5 + elif global_scale.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 1 = 14 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 1 = 126 + exponent_bias = 2**(target_exponent - 1) - 2**(fp4_exponent - 1) + return global_scale * (2.0**(exponent_bias - 7)) + + +def apply_fp4_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_scale_2: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + # For GPUs that lack FP4 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP4 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=weight_scale_2, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float4_e2m1f, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + param_dtype = layer.params_dtype + + assert layer.weight.shape == (part_size_n, part_size_k // 2) + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace_new(device) + + # WEIGHT + # Repack weights to marlin format + perm = torch.empty(0, dtype=torch.int, device=device) + qweight = layer.weight.view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=part_size_k, + size_n=part_size_n, + num_bits=4) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + # Permute scales + weight_scale = layer.weight_scale.T.to(param_dtype) + weight_scale = marlin_permute_scales(s=weight_scale, + size_k=part_size_k, + size_n=part_size_n, + group_size=16) + weight_scale = fp4_marlin_process_scales(weight_scale) + layer.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) + + weight_scale_2 = layer.weight_scale_2.to(param_dtype) + weight_scale_2 = fp4_marlin_process_global_scale(weight_scale_2) + layer.weight_scale_2 = torch.nn.Parameter(weight_scale_2, + requires_grad=False) + + return + + +def prepare_moe_fp4_layer_for_marlin(layer: torch.nn.Module) -> None: + logger.warning_once( + "Your GPU does not have native support for FP4 computation but " + "FP4 quantization is being used. Weight-only FP4 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads.") + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + + # WORKSPACE + device = layer.w13_weight.device + param_dtype = layer.params_dtype + layer.workspace = marlin_make_workspace_new(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + assert weight.shape == (e, size_n, size_k // 2) + + for i in range(e): + qweight = weight[i].view(torch.int32).T.contiguous() + + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=qweight, + perm=perm, + size_k=size_k, + size_n=size_n, + num_bits=4) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + for name in ["w13", "w2"]: + scales = getattr(layer, name + "_weight_scale").to(param_dtype) + global_scale = getattr(layer, name + "_weight_scale_2").to(param_dtype) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + for i in range(e): + marlin_scales = marlin_permute_scales(s=scales[i].T, + size_k=size_k, + size_n=size_n, + group_size=16) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = torch.nn.Parameter(scales, requires_grad=False) + setattr(layer, name + "_weight_scale", scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + global_scale = torch.nn.Parameter(global_scale, requires_grad=False) + setattr(layer, name + "_weight_scale_2", global_scale) + + +def rand_marlin_weight_fp4_like(weight, 
group_size): + assert group_size > 0 + size_n, size_k = weight.shape + device = weight.device + + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 6 + global_scale = scales.max() / 448 + scales = (scales / global_scale).to(torch.float8_e4m3fn) + + fp4_weight = torch.randint(0, + 256, (size_n, size_k // 2), + dtype=torch.uint8, + device=weight.device) + fp4_weight_part_1 = ((fp4_weight & 0b10000000) | + ((fp4_weight & 0b01110000) >> 2)) + fp4_weight_part_1 = fp4_weight_part_1.view(torch.float8_e4m3fn) + fp4_weight_part_1 = fp4_weight_part_1.to(weight.dtype) * (2**6) + + fp4_weight2 = fp4_weight << 4 + fp4_weight_part_2 = ((fp4_weight2 & 0b10000000) | + ((fp4_weight2 & 0b01110000) >> 2)) + fp4_weight_part_2 = fp4_weight_part_2.view(torch.float8_e4m3fn) + fp4_weight_part_2 = fp4_weight_part_2.to(weight.dtype) * (2**6) + + weight_ref = torch.cat( + [fp4_weight_part_2.unsqueeze(2), + fp4_weight_part_1.unsqueeze(2)], 2).view(size_n, size_k) + weight_ref = weight_ref * global_scale.to(weight.dtype) * \ + scales.repeat_interleave(group_size, 1).to(weight.dtype) + + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=fp4_weight.view(torch.int32).T.contiguous(), + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=4, + ) + + marlin_scales = marlin_permute_scales(s=scales.T.to(weight.dtype), + size_k=size_k, + size_n=size_n, + group_size=group_size) + marlin_scales = fp4_marlin_process_scales(marlin_scales) + + global_scale = fp4_marlin_process_global_scale(global_scale) + + return weight_ref.T, marlin_qweight, marlin_scales, global_scale diff --git a/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp8.py index b269fa6a4cee316e8299ecc86c3e7594b336b499..b38fe2d4aff0234cdbf08218da6440b4892e01f0 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/quantization/utils/marlin_utils_fp8.py @@ -1,10 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + from typing import Optional import torch import quantization as ops -from .marlin_utils import marlin_make_workspace, marlin_permute_scales +from .marlin_utils import USE_FP32_REDUCE_DEFAULT, marlin_make_workspace, marlin_permute_scales def is_fp8_marlin_supported(): @@ -13,88 +16,107 @@ def is_fp8_marlin_supported(): return capability >= 80 +def fp8_fused_exponent_bias_into_scales(scales): + fp8_exponent = 4 + if scales.dtype == torch.half: + target_exponent = 5 + elif scales.dtype == torch.bfloat16: + target_exponent = 8 + # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8 + # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120 + exponent_bias = 2**(target_exponent - 1) - 2**(fp8_exponent - 1) + s = torch.ones_like(scales) * 2 + s = s**exponent_bias + return scales * s + + def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: # For GPUs that lack FP8 hardware support, we can leverage the # Marlin kernel for fast weight-only FP8 quantization reshaped_x = input.reshape(-1, 
input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) + out_shape = input.shape[:-1] + (size_n, ) + + use_atomic_add = should_use_atomic_add_reduce(m=reshaped_x.size(0), + n=size_n, + k=size_k, + device=input.device, + dtype=input.dtype) + + output = ops.gptq_marlin_gemm(a=reshaped_x, + c=None, + b_q_weight=weight, + b_scales=weight_scale, + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=workspace, + b_q_type=scalar_types.float8_e4m3fn, + size_m=reshaped_x.size(0), + size_n=size_n, + size_k=size_k, + use_atomic_add=use_atomic_add, + use_fp32_reduce=use_fp32_reduce) if bias is not None: output.add_(bias) # In-place add return output.reshape(out_shape) - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: +def pack_fp8_to_int32(fp8_tensor: torch.Tensor, + size_k_first: bool = True) -> torch.Tensor: """ Repack FP8 weights to gptq format (packed int32 elements) """ assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + assert fp8_tensor.ndim == 2 + + fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor + fp8_tensor = fp8_tensor.contiguous() + # fp8_tensor is contiguous and have shape (N, K) now + # with `.view(torch.int32)`, it become (N, K // 4) + int32_tensor = fp8_tensor.view(torch.int32) + return int32_tensor.T.contiguous() if size_k_first else int32_tensor + + +def marlin_quant_fp8_torch(weight, group_size): + size_n, size_k = weight.shape + device = weight.device + + if group_size != -1: + scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(group_size, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + else: + scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448 + repeated_scales = scales.repeat_interleave(size_k, 1) + fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn) + weight_ref = fp8_weight.to(weight.dtype) * repeated_scales + + packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous() + marlin_qweight = ops.gptq_marlin_repack( + b_q_weight=packed_weight, + perm=torch.empty(0, dtype=torch.int, device=device), + size_k=size_k, + size_n=size_n, + num_bits=8, + ) - # Convert fp8 to uint8 (byte) representation - byte_tensor = 
reshaped.view(torch.uint8) + marlin_scales = marlin_permute_scales(s=scales.T, + size_k=size_k, + size_n=size_n, + group_size=group_size) - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) + marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales) - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() + return weight_ref.T, marlin_qweight, marlin_scales diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/__init__.py deleted file mode 100644 index c3ab3b032c29f7bbafd549915dbc677c45a33837..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -from .compressed_tensors import scaled_fp8_quant, scaled_int8_quant -from .cutlass import ( - cutlass_scaled_mm_supports_fp8, - cutlass_scaled_mm, - cutlass_scaled_mm_azp, -) -from .marlin import ( - awq_marlin_repack, - fp8_marlin_gemm, - gptq_marlin_gemm, - gptq_marlin_repack, - gptq_marlin_24_gemm, - marlin_qqq_gemm, - marlin_gemm, -) -from .scalar_type import ( - ScalarType, - scalar_types, -) -from ._ops import ops - - -__all__ = [ - "ScalarType", - "awq_marlin_repack", - "cutlass_scaled_mm", - "cutlass_scaled_mm_azp", - "cutlass_scaled_mm_supports_fp8", - "fp8_marlin_gemm", - "gptq_marlin_24_gemm", - "gptq_marlin_gemm", - "gptq_marlin_repack", - "marlin_gemm", - "marlin_qqq_gemm", - "ops", - "scalar_types", - "scaled_fp8_quant", - "scaled_int8_quant", -] diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/_ops.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/_ops.py deleted file mode 100644 index 6d88357f35586e1c68d6b98e547d6cf7d8dc4083..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/_ops.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -from . import _quantization_0435ccb -ops = torch.ops._quantization_0435ccb - -def add_op_namespace_prefix(op_name: str): - """ - Prefix op by namespace. - """ - return f"_quantization_0435ccb::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/_quantization_0435ccb.abi3.so b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/_quantization_0435ccb.abi3.so deleted file mode 100755 index b5ba4411808e6d751eb556d45f4d4e6ff7b3ba8a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/_quantization_0435ccb.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5d4ff6b2fec2a60c8bf1de4fb05d66bdd05534610fb5bd175dfa77a483590f21 -size 2857144 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/compressed_tensors.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/compressed_tensors.py deleted file mode 100644 index c3ba30bac87979a307fc5061a46f5d2cbf0efbf9..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/compressed_tensors.py +++ /dev/null @@ -1,110 +0,0 @@ -from typing import Optional, Tuple - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. 
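(Note on the `pack_fp8_to_int32` rewrite above.) The shift-and-or loop it replaces computed exactly what a little-endian reinterpretation gives for free: four consecutive FP8 bytes become one int32 word. A minimal sketch, assuming a little-endian x86_64 host and a PyTorch build with `torch.float8_e4m3fn` (2.1+); `n` and `k` are illustrative:

```python
import torch

n, k = 8, 16
w = torch.randn(n, k, dtype=torch.float16).to(torch.float8_e4m3fn)

# view-based packing: (n, k) fp8 bytes -> (n, k // 4) int32 words
packed = w.contiguous().view(torch.int32)

# reference packing with explicit shifts (done in int64 to avoid int32 overflow)
b = w.view(torch.uint8).to(torch.int64).reshape(n, k // 4, 4)
ref = b[..., 0] | (b[..., 1] << 8) | (b[..., 2] << 16) | (b[..., 3] << 24)

# the reinterpretation matches the manual pack bit-for-bit and is lossless
assert torch.equal(packed.to(torch.int64) & 0xFFFFFFFF, ref)
assert torch.equal(packed.view(torch.float8_e4m3fn).view(torch.uint8), w.view(torch.uint8))
```

No arithmetic touches the FP8 bit patterns, which is why the repack round-trips exactly.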
- try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -# fp8 -def scaled_fp8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - num_token_padding: Optional[int] = None, - scale_ub: Optional[torch.Tensor] = None, - use_per_token_if_dynamic: bool = False, -) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantize input tensor to FP8 and return quantized tensor and scale. - - This function supports both static and dynamic quantization: If you - provide the scale, it will use static scaling and if you omit it, - the scale will be determined dynamically. The function also allows - optional padding of the output tensors for downstream kernels that - will benefit from padding. - - Args: - input: The input tensor to be quantized to FP8 - scale: Optional scaling factor for the FP8 quantization - scale_ub: Optional upper bound for scaling factor in dynamic - per token case - num_token_padding: If specified, pad the first dimension - of the output to at least this value. - use_per_token_if_dynamic: Whether to do per_tensor or per_token - in the dynamic quantization case. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and - scaling factor. - """ - # This code assumes batch_dim and num_tokens are flattened - assert input.ndim == 2 - shape: Union[Tuple[int, int], torch.Size] = input.shape - # For rocm, the output fp8 dtype is torch.float_e3m3fnuz - # out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - # if current_platform.is_rocm() else torch.float8_e4m3fn - out_dtype = torch.float8_e4m3fn - if num_token_padding: - shape = (max(num_token_padding, input.shape[0]), shape[1]) - output = torch.empty(shape, device=input.device, dtype=out_dtype) - - if scale is None: - if use_per_token_if_dynamic: - scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32) - ops.dynamic_per_token_scaled_fp8_quant(output, input, scale, scale_ub) - else: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) - ops.dynamic_scaled_fp8_quant(output, input, scale) - else: - # num_token_padding not implemented for this case - assert scale.numel() == 1 or num_token_padding is None - ops.static_scaled_fp8_quant(output, input, scale) - - return output, scale - - -# int8 -def scaled_int8_quant( - input: torch.Tensor, - scale: Optional[torch.Tensor] = None, - azp: Optional[torch.Tensor] = None, - symmetric: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. - - Args: - input: The input tensor to be quantized to int8. - scale: Optional scaling factor for the int8 quantization. - When not provided, we invoke dynamic-per-token quantization. - azp: Optional zero-point for the int8 quantization. - Must be provided for asymmetric quantization if `scale` is provided. - symmetric: Whether to use symmetric quantization (scale only, azp ignored). - - Returns: - Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. - """ - output = torch.empty_like(input, dtype=torch.int8) - if scale is not None: - # static-per-tensor quantization. - assert symmetric == ( - azp is None - ), "azp must only be provided for asymmetric quantization." - ops.static_scaled_int8_quant(output, input, scale, azp) - return output, scale, azp - - # dynamic-per-token quantization. 
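For orientation, a rough pure-PyTorch sketch of what the symmetric dynamic-per-token path above computes; the fused `dynamic_scaled_int8_quant` op may differ in rounding and zero-row handling, and the shapes here are illustrative:

```python
import torch

x = torch.randn(4, 64, dtype=torch.float16)

# one scale per token (row): absmax / 127; a real kernel would guard all-zero rows
scales = x.abs().amax(dim=-1, keepdim=True).to(torch.float32) / 127.0
scales = scales.clamp_min(1e-8)

q = torch.clamp(torch.round(x.to(torch.float32) / scales), -128, 127).to(torch.int8)
x_dequant = q.to(torch.float32) * scales   # approximate reconstruction

print(scales.shape, q.dtype)               # torch.Size([4, 1]) torch.int8
```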
- input_scales = torch.empty( - (input.numel() // input.shape[-1], 1), device=input.device, dtype=torch.float32 - ) - input_azp = None if symmetric else torch.empty_like(input_scales, dtype=torch.int32) - ops.dynamic_scaled_int8_quant(output, input, input_scales, input_azp) - return output, input_scales, input_azp diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/cutlass.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/cutlass.py deleted file mode 100644 index c378b846d0c59de183a321fcad4b403c47b3d750..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/cutlass.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -import torch - -try: - from ._ops import ops -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - except ImportError: - raise e - - -def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: - return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability) - - -def cutlass_scaled_mm( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype - - m = a.shape[0] - n = b.shape[1] - - # if current_platform.is_rocm(): - # triton_scaled_mm_module = importlib.import_module( - # "vllm.model_executor.layers.quantization.compressed_tensors." - # "triton_scaled_mm") - # triton_scaled_mm = triton_scaled_mm_module.triton_scaled_mm - # return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias) - - return out - - -def cutlass_scaled_mm_azp( - a: torch.Tensor, - b: torch.Tensor, - scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype, - azp_adj: torch.Tensor, - azp: Optional[torch.Tensor] = None, - bias: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - :param azp_adj: In the per-tensor case, this should include the azp. - Always per-channel. - :param azp: Only set in the per-token case. Per-token if set. 
- """ - assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 - assert out_dtype is torch.bfloat16 or out_dtype is torch.float16 - assert bias is None or bias.numel() == b.shape[1] and bias.dtype == out_dtype - assert azp is None or azp.numel() == a.shape[0] - - m = a.shape[0] - n = b.shape[1] - out = torch.empty((m, n), dtype=out_dtype, device=a.device) - - ops.cutlass_scaled_mm_azp(out, a, b, scale_a, scale_b, azp_adj, azp, bias) - return out diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/marlin.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/marlin.py deleted file mode 100644 index 44d5d28a2fb67af955c017af3cf1403feeecbd32..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/marlin.py +++ /dev/null @@ -1,208 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -# neuron has torch version that doesn't even have impl_abstract -if TYPE_CHECKING: - def register_fake(fn): - return lambda name: fn -else: - try: - from torch.library import register_fake - except ImportError: - from torch.library import impl_abstract as register_fake - -try: - from ._ops import ops, add_op_namespace_prefix -except ImportError as e: - # Fallback for local development. - try: - import _quantization - - ops = torch.ops._quantization - - def add_op_namespace_prefix(op_name: str): - return f"_quantization::{op_name}" - except ImportError: - raise e - - -from .scalar_type import ScalarType - - -# fp8 marlin -def fp8_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.fp8_marlin_gemm( - a, b_q_weight, b_scales, workspace, num_bits, size_m, size_n, size_k - ) - - -# gptq_marlin -def gptq_marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False, -) -> torch.Tensor: - return ops.gptq_marlin_gemm( - a, - b_q_weight, - b_scales, - b_zeros, - g_idx, - perm, - workspace, - b_q_type.id, - size_m, - size_n, - size_k, - is_k_full, - has_zp, - use_fp32_reduce, - is_zp_float, - ) - - -# gptq_marlin -def gptq_marlin_repack( - b_q_weight: torch.Tensor, - perm: torch.Tensor, - size_k: int, - size_n: int, - num_bits: int, -) -> torch.Tensor: - return ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits) - - -# gptq_marlin -def awq_marlin_repack( - b_q_weight: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - return ops.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) - - -# marlin -def marlin_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_gemm( - a, b_q_weight, b_scales, workspace, size_m, size_n, size_k - ) - - -# marlin_24 -def gptq_marlin_24_gemm( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.gptq_marlin_24_gemm( - a, b_q_weight, b_meta, b_scales, workspace, b_q_type.id, size_m, size_n, size_k - ) - - -# qqq ops -def marlin_qqq_gemm( - a: 
torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: int, - size_n: int, - size_k: int, -) -> torch.Tensor: - return ops.marlin_qqq_gemm( - a, b_q_weight, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k - ) - - -# Fake ops - -if hasattr(ops, "gptq_marlin_24_gemm"): - @register_fake(add_op_namespace_prefix("fp8_marlin_gemm")) - def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - num_bits: int, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake(add_op_namespace_prefix("gptq_marlin_24_gemm")) - def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_meta: torch.Tensor, b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("gptq_marlin_gemm")) - def _gptq_marlin_gemm_fake(a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type: ScalarType, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - is_zp_float: bool = False) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake(add_op_namespace_prefix("marlin_qqq_gemm")) - def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - s_tok: torch.Tensor, s_ch: torch.Tensor, - s_group: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) - - @register_fake(add_op_namespace_prefix("marlin_gemm")) - def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, - b_scales: torch.Tensor, workspace: torch.Tensor, - size_m: torch.SymInt, size_n: torch.SymInt, - size_k: torch.SymInt) -> torch.Tensor: - return torch.empty((size_m, size_n), - dtype=torch.float16, - device=a.device) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/scalar_type.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/scalar_type.py deleted file mode 100644 index 9d711b0debcd8aaa343818edc9d6bbca20587d0a..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/scalar_type.py +++ /dev/null @@ -1,330 +0,0 @@ -import functools -import struct -from dataclasses import dataclass -from enum import Enum -from typing import Optional, Union - - -# Mirrors enum in `core/scalar_type.hpp` -class NanRepr(Enum): - NONE = 0 # nans are not supported - IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s - EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s - - -# This ScalarType class is a parallel implementation of the C++ ScalarType -# class found in csrc/core/scalar_type.hpp. These two classes should be kept -# in sync until the inductor fully supports custom C++ classes. 
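The `register_fake` registrations in the deleted `marlin.py` above exist so that `torch.compile` and meta/fake tracing can derive output shapes without launching the real CUDA kernels. A self-contained sketch of the same pattern, assuming PyTorch >= 2.4; the `demo::toy_gemm` op is purely illustrative and not part of this package:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

@torch.library.custom_op("demo::toy_gemm", mutates_args=())
def toy_gemm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a @ b  # "real" implementation

@toy_gemm.register_fake
def _(a, b):
    # Shape/dtype-only description, analogous to the _*_gemm_fake functions above.
    return a.new_empty((a.shape[0], b.shape[1]))

with FakeTensorMode():
    out = toy_gemm(torch.empty(128, 64), torch.empty(64, 256))

print(out.shape)  # torch.Size([128, 256]) -- derived without running the matmul
```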
-@dataclass(frozen=True) -class ScalarType: - """ - ScalarType can represent a wide range of floating point and integer - types, in particular it can be used to represent sub-byte data types - (something that torch.dtype currently does not support). It is also - capable of representing types with a bias, i.e.: - `stored_value = value + bias`, - this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias - of 8). The implementation for this class can be found in - csrc/core/scalar_type.hpp, these type signatures should be kept in sync - with that file. - """ - - exponent: int - """ - Number of bits in the exponent if this is a floating point type - (zero if this an integer type) - """ - - mantissa: int - """ - Number of bits in the mantissa if this is a floating point type, - or the number bits representing an integer excluding the sign bit if - this an integer type. - """ - - signed: bool - "If the type is signed (i.e. has a sign bit)" - - bias: int - """ - bias used to encode the values in this scalar type - (value = stored_value - bias, default 0) for example if we store the - type as an unsigned integer with a bias of 128 then the value 0 will be - stored as 128 and -1 will be stored as 127 and 1 will be stored as 129. - """ - - _finite_values_only: bool = False - """ - Private: if infs are supported, used `has_infs()` instead. - """ - - nan_repr: NanRepr = NanRepr.IEEE_754 - """ - How NaNs are represent in this scalar type, returns NanRepr value. - (not applicable for integer types) - """ - - def _floating_point_max_int(self) -> int: - assert ( - self.mantissa <= 52 and self.exponent <= 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - - max_mantissa = (1 << self.mantissa) - 1 - if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN: - max_mantissa = max_mantissa - 1 - - max_exponent = (1 << self.exponent) - 2 - if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN - or self.nan_repr == NanRepr.NONE): - assert ( - self.exponent < 11 - ), f"Cannot represent max/min as a double for type {self.__str__()}" - max_exponent = max_exponent + 1 - - # adjust the exponent to match that of a double - # for now we assume the exponent bias is the standard 2^(e-1) -1, (where - # e is the exponent bits), there is some precedent for non-standard - # biases, example `float8_e4m3b11fnuz` here: - # https://github.com/jax-ml/ml_dtypes but to avoid premature over - # complication we are just assuming the standard exponent bias until - # there is a need to support non-standard biases - exponent_bias = (1 << (self.exponent - 1)) - 1 - exponent_bias_double = (1 << 10) - 1 # double e = 11 - - max_exponent_double = (max_exponent - exponent_bias + - exponent_bias_double) - - # shift the mantissa and exponent into the proper positions for an - # IEEE double and bitwise-or them together. 
- return (max_mantissa << - (52 - self.mantissa)) | (max_exponent_double << 52) - - def _floating_point_max(self) -> float: - double_raw = self._floating_point_max_int() - return struct.unpack('!d', struct.pack('!Q', double_raw))[0] - - def _raw_max(self) -> Union[int, float]: - if self.is_floating_point(): - return self._floating_point_max() - else: - assert (self.size_bits < 64 or self.size_bits == 64 - and self.is_signed()), "Cannot represent max as an int" - return (1 << self.mantissa) - 1 - - def _raw_min(self) -> Union[int, float]: - if self.is_floating_point(): - assert self.is_signed( - ), "We currently assume all floating point types are signed" - sign_bit_double = 1 << 63 - - max_raw = self._floating_point_max_int() - min_raw = max_raw | sign_bit_double - return struct.unpack('!d', struct.pack('!Q', min_raw))[0] - else: - assert (not self.is_signed() or - self.size_bits <= 64), "Cannot represent min as a int64_t" - - if self.is_signed(): - return -(1 << (self.size_bits - 1)) - else: - return 0 - - @functools.cached_property - def id(self) -> int: - """ - Convert the ScalarType to an int which can be passed to pytorch custom - ops. This layout of the int must be kept in sync with the C++ - ScalarType's from_id method. - """ - val = 0 - offset = 0 - - def or_and_advance(member, bit_width): - nonlocal val - nonlocal offset - bit_mask = (1 << bit_width) - 1 - val = val | (int(member) & bit_mask) << offset - offset = offset + bit_width - - or_and_advance(self.exponent, 8) - or_and_advance(self.mantissa, 8) - or_and_advance(self.signed, 1) - or_and_advance(self.bias, 32) - or_and_advance(self._finite_values_only, 1) - or_and_advance(self.nan_repr.value, 8) - - assert offset <= 64, \ - f"ScalarType fields too big {offset} to fit into an int64" - - return val - - @property - def size_bits(self) -> int: - return self.exponent + self.mantissa + int(self.signed) - - def min(self) -> Union[int, float]: - """ - Min representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_min() - self.bias - - def max(self) -> Union[int, float]: - """ - Max representable value for this scalar type. - (accounting for bias if there is one) - """ - return self._raw_max() - self.bias - - def is_signed(self) -> bool: - """ - If the type is signed (i.e. 
has a sign bit), same as `signed` - added for consistency with: - https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html - """ - return self.signed - - def is_floating_point(self) -> bool: - "If the type is a floating point type" - return self.exponent != 0 - - def is_integer(self) -> bool: - "If the type is an integer type" - return self.exponent == 0 - - def has_bias(self) -> bool: - "If the type has a non-zero bias" - return self.bias != 0 - - def has_infs(self) -> bool: - "If the type is floating point and supports infinity" - return not self._finite_values_only - - def has_nans(self) -> bool: - return self.nan_repr != NanRepr.NONE.value - - def is_ieee_754(self) -> bool: - """ - If the type is a floating point type that follows IEEE 754 - conventions - """ - return self.nan_repr == NanRepr.IEEE_754.value and \ - not self._finite_values_only - - def __str__(self) -> str: - """ - naming generally follows: https://github.com/jax-ml/ml_dtypes - for floating point types (leading f) the scheme is: - `float_em[flags]` - flags: - - no-flags: means it follows IEEE 754 conventions - - f: means finite values only (no infinities) - - n: means nans are supported (non-standard encoding) - for integer types the scheme is: - `[u]int[b]` - - if bias is not present it means its zero - """ - if self.is_floating_point(): - ret = "float" + str(self.size_bits) + "_e" + str( - self.exponent) + "m" + str(self.mantissa) - - if not self.is_ieee_754(): - if self._finite_values_only: - ret = ret + "f" - if self.nan_repr != NanRepr.NONE: - ret = ret + "n" - - return ret - else: - ret = ("int" if self.is_signed() else "uint") + str(self.size_bits) - if self.has_bias(): - ret = ret + "b" + str(self.bias) - return ret - - def __repr__(self) -> str: - return "ScalarType." + self.__str__() - - # __len__ needs to be defined (and has to throw TypeError) for pytorch's - # opcheck to work. - def __len__(self) -> int: - raise TypeError - - # - # Convenience Constructors - # - - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - "Create a signed integer scalar type (size_bits includes sign-bit)." - ret = cls(0, size_bits - 1, True, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - """Create a unsigned integer scalar type.""" - ret = cls(0, size_bits, False, bias if bias else 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - """ - Create a standard floating point type - (i.e. follows IEEE 754 conventions). - """ - assert (mantissa > 0 and exponent > 0) - ret = cls(exponent, mantissa, True, 0) - ret.id # noqa B018: make sure the id is cached - return ret - - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: NanRepr) -> 'ScalarType': - """ - Create a non-standard floating point type - (i.e. does not follow IEEE 754 conventions). 
- """ - assert (mantissa > 0 and exponent > 0) - assert (nan_repr != NanRepr.IEEE_754), ( - "use `float_IEEE754` constructor for floating point types that " - "follow IEEE 754 conventions") - ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr) - ret.id # noqa B018: make sure the id is cached - return ret - - -# naming generally follows: https://github.com/jax-ml/ml_dtypes -# for floating point types (leading f) the scheme is: -# `float_em[flags]` -# flags: -# - no-flags: means it follows IEEE 754 conventions -# - f: means finite values only (no infinities) -# - n: means nans are supported (non-standard encoding) -# for integer types the scheme is: -# `[u]int[b]` -# - if bias is not present it means its zero - - -class scalar_types: - int4 = ScalarType.int_(4, None) - uint4 = ScalarType.uint(4, None) - int8 = ScalarType.int_(8, None) - uint8 = ScalarType.uint(8, None) - float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN) - float8_e5m2 = ScalarType.float_IEEE754(5, 2) - float16_e8m7 = ScalarType.float_IEEE754(8, 7) - float16_e5m10 = ScalarType.float_IEEE754(5, 10) - - # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main - float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE) - - # "gptq" types - uint2b2 = ScalarType.uint(2, 2) - uint3b4 = ScalarType.uint(3, 4) - uint4b8 = ScalarType.uint(4, 8) - uint8b128 = ScalarType.uint(8, 128) - - # colloquial names - bfloat16 = float16_e8m7 - float16 = float16_e5m10 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/__init__.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils.py deleted file mode 100644 index b1c94c38858a5cd6f02eb134d1a94b99a2b15566..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils.py +++ /dev/null @@ -1,391 +0,0 @@ -from typing import List, Optional, Tuple - -import numpy -import torch - -import quantization as ops -from quantization.scalar_type import ScalarType, scalar_types - -from .quant_utils import pack_cols, unpack_cols - -GPTQ_MARLIN_TILE = 16 -GPTQ_MARLIN_MIN_THREAD_N = 64 -GPTQ_MARLIN_MIN_THREAD_K = 128 -GPTQ_MARLIN_MAX_PARALLEL = 16 - -GPTQ_MARLIN_24_TILE = 16 -GPTQ_MARLIN_24_MIN_THREAD_N = 128 -GPTQ_MARLIN_24_MIN_THREAD_K = 128 -GPTQ_MARLIN_24_MAX_PARALLEL = 64 - -GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128] - -MARLIN_QQQ_TILE = 16 -MARLIN_QQQ_MIN_THREAD_N = 64 -MARLIN_QQQ_MIN_THREAD_K = 128 -MARLIN_QQQ_MAX_PARALLEL = 16 - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] -MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] -MARLIN_QQQ_SUPPORTED_SYM = [True] - -MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -# In case there is a performance issue with Marlin, the variable below can be -# changed to False, which allows Marlin to perform global reductions in fp16 -# precision (instead of fp32), and therefore, save on some memory movements. -USE_FP32_REDUCE_DEFAULT = True - - -# For binary size and compile time, we don't support the same types for with and -# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
-# TODO: we may want to move this into the C++ so its closer to the actual impl -def query_marlin_supported_quant_types( - has_zp: bool, device_capability: Optional[int] = None -): - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - if device_capability < 80: - return [] - - if has_zp: - # AWQ style, unsigned + runtime zero-point - return [scalar_types.uint4, scalar_types.uint8] - else: - # GPTQ style, unsigned + symmetric bias - # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able - # to add `scalar_types.float8_e4m3fn` here - return [scalar_types.uint4b8, scalar_types.uint8b128] - - -def _check_marlin_supported( - quant_type: ScalarType, - group_size: Optional[int], - has_zp: bool, - device_capability: Optional[int] = None, -) -> Tuple[bool, Optional[str]]: - - if device_capability is None: - capability_tuple = torch.cuda.get_device_capability() - device_capability = capability_tuple[0] * 10 + capability_tuple[1] - - supported_types = query_marlin_supported_quant_types(has_zp, device_capability) - - if quant_type not in supported_types: - return ( - False, - f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).", - ) - if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES: - return ( - False, - f"Marlin does not support group_size = {group_size}. " - f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " - "are supported.", - ) - - return True, None - - -def check_marlin_supported( - quant_type: ScalarType, - group_size: int, - has_zp: bool = False, - device_capability: Optional[int] = None, -) -> bool: - cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability) - return cond - - -def verify_marlin_supported( - quant_type: ScalarType, group_size: int, has_zp: bool = False -) -> None: - cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) - if not cond: - assert err_msg is not None - raise ValueError(err_msg) - - -def verify_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> None: - - # Validate output_size_per_partition - if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: - raise ValueError( - f"Weight output_size_per_partition = " - f"{output_size_per_partition} is not divisible by " - f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - # Validate input_size_per_partition - if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: - raise ValueError( - f"Weight input_size_per_partition = " - f"{input_size_per_partition} is not divisible " - f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." - ) - - if group_size < input_size and input_size_per_partition % group_size != 0: - raise ValueError( - f"Weight input_size_per_partition = {input_size_per_partition}" - f" is not divisible by group_size = {group_size}." - "Consider reducing tensor_parallel_size or running " - "with --quantization gptq." 
- ) - - -def check_marlin_supports_shape( - output_size_per_partition: int, - input_size_per_partition: int, - input_size: int, - group_size: int, -) -> Tuple[bool, Optional[str]]: - try: - verify_marlin_supports_shape( - output_size_per_partition, input_size_per_partition, input_size, group_size - ) - except ValueError as e: - return False, e.__str__() - return True, None - - -def marlin_make_workspace( - output_size_per_partition: int, device: torch.device -) -> torch.Tensor: - max_workspace_size = ( - output_size_per_partition // GPTQ_MARLIN_MIN_THREAD_N - ) * GPTQ_MARLIN_MAX_PARALLEL - - return torch.zeros( - max_workspace_size, dtype=torch.int, device=device, requires_grad=False - ) - - -def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: - return (not act_order) or (act_order and not is_row_parallel) - - -def marlin_repeat_scales_on_all_ranks( - act_order: bool, group_size: int, is_row_parallel: bool -) -> bool: - # Need to repeat scales on every rank if act_ordering or - # channelwise and RowParallelLinear - is_channelwise = group_size == -1 - return act_order or (is_channelwise and is_row_parallel) - - -def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: - return torch.nn.Parameter( - torch.empty(0, dtype=torch.int, device=device), requires_grad=False - ) - - -def marlin_sort_g_idx(g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) - return g_idx[g_idx_sort_indices], g_idx_sort_indices - - -def get_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -def marlin_permute_scales( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_moe_permute_scales( - s: torch.Tensor, - size_k: int, - size_n: int, - group_size: int, -): - num_experts = s.shape[0] - output = torch.empty( - (num_experts, s.shape[1], s.shape[2]), - device=s.device, - dtype=s.dtype, - ) - - for e in range(num_experts): - output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) - return output - - -def marlin_zero_points( - zp: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # Permute zero-points in a similar way to scales, but do not use the - # "single" permutation, since zero-points are applied on every MMA - scale_perm, _ = get_scale_perms() - zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() - zp = zp.reshape((-1, size_n)).contiguous() - zp = pack_cols(zp, num_bits, size_k, size_n) - - return zp - 
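A hedged usage sketch tying the `marlin_utils.py` helpers above together on a CUDA machine with compute capability >= 8.0; the `quantization.utils.marlin_utils` import path and the layer dimensions are assumptions for illustration:

```python
import torch
from quantization.scalar_type import scalar_types
from quantization.utils.marlin_utils import (
    marlin_make_workspace,
    verify_marlin_supported,
    verify_marlin_supports_shape,
)

out_features, in_features, group_size = 4096, 11008, 128

# Both helpers raise ValueError with an explanatory message if unsupported.
verify_marlin_supported(scalar_types.uint4b8, group_size=group_size, has_zp=False)
verify_marlin_supports_shape(out_features, in_features, in_features, group_size)

workspace = marlin_make_workspace(out_features, torch.device("cuda"))
print(workspace.shape)  # (out_features // 64) * 16 zero-initialized int32 entries
```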
- -def awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -) -> torch.Tensor: - # AWQ zero-points are quantized and packed on the column dim. - # In addition, the values are permuted based on dequantizer. - # Here we undo both of these, and then apply marlin permutation - # and pack it back. - q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) - - # Undo interleaving (use argsort(..) to get inverse perm) - if num_bits == 4: - undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) - elif num_bits == 8: - undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() - q_zp = q_zp.reshape((-1, size_n)).contiguous() - - marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) - return marlin_zp - - -def moe_awq_to_marlin_zero_points( - q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int -): - num_experts = q_zp_packed.shape[0] - output = torch.empty( - (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), - device=q_zp_packed.device, - dtype=q_zp_packed.dtype, - ) - for e in range(num_experts): - output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits) - return output - - -def apply_gptq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - wtype: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - is_k_full: bool, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - wtype, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=is_k_full, - has_zp=False, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def apply_awq_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - weight_zp: torch.Tensor, - g_idx: torch.Tensor, - g_idx_sort_indices: torch.Tensor, - workspace: torch.Tensor, - quant_type: ScalarType, - output_size_per_partition: int, - input_size_per_partition: int, - bias: Optional[torch.Tensor] = None, - use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT, -) -> torch.Tensor: - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (output_size_per_partition,) - - output = ops.gptq_marlin_gemm( - reshaped_x, - weight, - weight_scale, - weight_zp, - g_idx, - g_idx_sort_indices, - workspace, - quant_type, - size_m=reshaped_x.shape[0], - size_n=output_size_per_partition, - size_k=input_size_per_partition, - is_k_full=True, - has_zp=True, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_fp8.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_fp8.py deleted file mode 100644 index 
b269fa6a4cee316e8299ecc86c3e7594b336b499..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_fp8.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Optional - -import torch - -import quantization as ops - -from .marlin_utils import marlin_make_workspace, marlin_permute_scales - - -def is_fp8_marlin_supported(): - capability = torch.cuda.get_device_capability() - capability = capability[0] * 10 + capability[1] - return capability >= 80 - - -def apply_fp8_marlin_linear( - input: torch.Tensor, - weight: torch.Tensor, - weight_scale: torch.Tensor, - workspace: torch.Tensor, - size_n: int, - size_k: int, - bias: Optional[torch.Tensor], -) -> torch.Tensor: - # For GPUs that lack FP8 hardware support, we can leverage the - # Marlin kernel for fast weight-only FP8 quantization - - reshaped_x = input.reshape(-1, input.shape[-1]) - out_shape = input.shape[:-1] + (size_n,) - - output = ops.fp8_marlin_gemm( - a=reshaped_x, - b_q_weight=weight, - b_scales=weight_scale, - workspace=workspace, - num_bits=8, - size_m=reshaped_x.shape[0], - size_n=size_n, - size_k=size_k, - ) - - if bias is not None: - output.add_(bias) # In-place add - - return output.reshape(out_shape) - - -def prepare_fp8_layer_for_marlin( - layer: torch.nn.Module, strategy: str = "tensor" -) -> None: - part_size_n = layer.output_size_per_partition - part_size_k = layer.input_size_per_partition - - device = layer.weight.device - - # WORKSPACE - layer.workspace = marlin_make_workspace(part_size_n, device) - - # WEIGHT - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=pack_fp8_to_int32(layer.weight), - perm=torch.empty(0, dtype=torch.int, device=device), - size_k=part_size_k, - size_n=part_size_n, - num_bits=8, - ) - layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) - - # WEIGHT SCALES - scales = layer.weight_scale.to(layer.orig_dtype) - # Permute scales - marlin_scales = marlin_permute_scales( - s=scales, size_k=part_size_k, size_n=part_size_n, group_size=-1 - ) - layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) - - -def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: - """ - Repack FP8 weights to gptq format (packed int32 elements) - """ - assert fp8_tensor.dtype == torch.float8_e4m3fn - assert fp8_tensor.shape[0] % 4 == 0 - - # Reshape to prepare for packing - reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) - - # Convert fp8 to uint8 (byte) representation - byte_tensor = reshaped.view(torch.uint8) - - # Pack 4 uint8 values into one int32 - packed = ( - byte_tensor[:, 0].to(torch.int32) - | (byte_tensor[:, 1].to(torch.int32) << 8) - | (byte_tensor[:, 2].to(torch.int32) << 16) - | (byte_tensor[:, 3].to(torch.int32) << 24) - ) - - return packed.view(fp8_tensor.shape[0] // 4, *fp8_tensor.shape[1:]).contiguous() diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test.py deleted file mode 100644 index 7d4f5f3cfbb872bf7b32e0972d6143b43f354a5e..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test.py +++ /dev/null @@ -1,162 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -from typing import List, Optional - -import numpy as np -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils import GPTQ_MARLIN_TILE, marlin_permute_scales, 
marlin_zero_points -from .quant_utils import ( - get_pack_factor, - gptq_quantize_weights, - quantize_weights, - sort_weights, -) - - -class MarlinWorkspace: - - def __init__(self, out_features, min_thread_n, max_parallel): - assert ( - out_features % min_thread_n == 0 - ), "out_features = {} is undivisible by min_thread_n = {}".format( - out_features, min_thread_n - ) - - max_workspace_size = (out_features // min_thread_n) * max_parallel - - self.scratch = torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") - - -def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): - assert q_w.shape == (size_k, size_n) - assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" - assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" - - # Permute weights to 16x64 marlin tiles - q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) - q_w = q_w.permute((0, 2, 1, 3)) - q_w = q_w.reshape((size_k // tile, size_n * tile)) - - q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) - - return q_w - - -def marlin_weights(q_w, size_k, size_n, num_bits, perm): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(np.uint32) - - q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32) - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) - - return q_packed - - -def get_weight_perm(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm_list) - - if num_bits == 4: - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = np.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, size_n = w.shape - num_bits = quant_type.size_bits - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Quantize (and apply act_order if provided) - w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order, test_perm - ) - - # For act_order, sort the "weights" and "g_idx" so that group ids are - # increasing - sort_indices = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) - - # Reformat to marlin - weight_perm = get_weight_perm(num_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list - - -def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int): - size_k, size_n = w.shape - - # Normalize group_size - if 
group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Detect num groups - assert size_k % group_size == 0 - num_groups = size_k // group_size - - # Quantize with zp - w_ref, q_w, s, zp = quantize_weights(w, quant_type, group_size, zero_points=True) - - # Reformat to marlin - weight_perm = get_weight_perm(quant_type.size_bits) - marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, weight_perm) - marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) - marlin_zp = marlin_zero_points(zp, num_groups, size_n, quant_type.size_bits) - - # Create result - res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test_24.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test_24.py deleted file mode 100644 index 927fa9016ba25f381c09d768db0c468066193a76..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test_24.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Utility functions used for tests and benchmarks""" - -import random -from typing import List - -import numpy -import torch - -from quantization.scalar_type import ScalarType - -from .marlin_utils_test import marlin_weights -from .quant_utils import gptq_quantize_weights - - -# This is PyTorch implementation of main part of reorder_meta() -# function, from tools/util/include/cutlass/util/host_reorder.h file -# of CUTLASS source tree. Furthermore, CUTLASS template for sparse -# GEMM decides upon layout of this matrix, and at the moment for the -# sparse GEMM executed on tensor cores, this is layout described by -# ColumnMajorInterleaved<2> data structure, in -# include/cutlass/layout/matrix.h of CUTLASS source tree. The -# reordering of meta matrix into meta_reordered matrix calculated -# according to these segments of CUTLASS code is re-implemented here. -# Note that this calculation produces offsets for scattering metadata -# matrix elements into reordered metadata matrix elements (or, -# equivalently, for gathering reordered metadata matrix element back -# into metadata matrix elements). -def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device): - dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) - dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) - - # Reorder the rows, then swizzle the 2x2 blocks. - group_x = 64 - group_y = 32 if meta_dtype.itemsize == 2 else 16 - - dst_rows = ( - dst_rows // group_x * group_x - + (dst_rows % 2) * 2 - + (dst_rows % 8) // 4 - + ((dst_rows % group_y) % 4) // 2 * 32 - + ((dst_rows % group_x) // 8) * 4 - ) - - topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) - bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) - dst_rows += topright - bottomleft - dst_cols -= topright - bottomleft - - # Assumed that meta tensor is to be stored in CUTLASS - # InterleavedColumnMajor layout, and reverse engineered - # corresponding code to store values into this tensor. 
- interleave = 2 - cols_maj = dst_cols // interleave - cols_min = dst_cols % interleave - return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1) - - -# This function converts dense matrix into sparse semi-structured -# representation, producing "compressed" matrix, in the layout used by -# CUTLASS backend, and corresponding metadata matrix. -def sparse_semi_structured_from_dense_cutlass(dense): - if dense.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = dense.shape - device = dense.device - - meta_dtype = torch.int8 - if dense.dtype == torch.int8: - meta_dtype = torch.int32 - elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: - meta_dtype = torch.int16 - else: - raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - if quadbits_per_meta_elem not in (4, 8): - raise RuntimeError("Invalid number of elements per meta element calculated") - - if meta_dtype == torch.int32: - if m % 16 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 16" - ) - else: - if m % 32 != 0: - raise RuntimeError( - f"Number of rows of dense matrix {m} must be divisible by 32" - ) - if k % (4 * quadbits_per_meta_elem) != 0: - raise RuntimeError( - f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 - ) - - if dense.dtype != torch.float: - ksparse = 4 - dense_4 = dense.view(-1, k // ksparse, ksparse) - m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) - else: - ksparse = 2 - dense_2 = dense.view(-1, k // ksparse, ksparse) - m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) - meta_ncols = k // (ksparse * quadbits_per_meta_elem) - - # Encoding quadruples of True/False values as follows: - # [True, True, False, False] -> 0b0100 - # [True, False, True, False] -> 0b1000 - # [False, True, True, False] -> 0b1001 - # [True, False, False, True ] -> 0b1100 - # [False, True, False, True ] -> 0b1101 - # [False, False, True, True ] -> 0b1110 - # Thus, lower two bits in the encoding are index of the True value - # at the lowest index in the quadruple, and the higher two bits in - # the encoding are index of the other True value in the quadruple. - # In case there are less than two True values, than False value or - # values at some index or indices are considered True for the - # encoding. In case there are more than two True values, then the - # excess True value(s) at some indices are considered False for - # the encoding. The exact encodings used for these cases are as - # follows: - # [False, False, False, False] -> 0b1110 - # [False, False, False, True ] -> 0b1110 - # [False, False, True, False] -> 0b1110 - # [False, True, False, False] -> 0b1001 - # [False, True, True, True ] -> 0b1101 - # [True, False, False, False] -> 0b1000 - # [True, False, True, True ] -> 0b1100 - # [True, True, False, True ] -> 0b0100 - # [True, True, True, False] -> 0b0100 - # [True, True, True, True ] -> 0b0100 - # These particular encodings are chosen, with the help of Espresso - # logic minimizer software, for the purpose of minimization of - # corresponding Boolean functions, that translate non-zero flags - # into encoding bits. Note also possible choices for the first - # and last of these encodings were limited only to (0b0100, - # 0b1110), in order to produce valid encodings for 1:2 sparsity - # case. 
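The encoding table in the comment above can be checked directly against the boolean formulas used just below; a small pure-Python sketch (quadruple order and bit layout as described in the comment):

```python
def meta_nibble(m0, m1, m2, m3):
    # Mirrors expr0/expr1/expr2 and bit0..bit3 from the function body:
    # low two bits = index of the first kept value, high two bits = the second.
    expr0, expr1, expr2 = m0 and m1, (not m0) and m1, (not m0) and (not m1)
    bit0 = expr1
    bit1 = expr2
    bit2 = expr0 or expr2 or m3
    bit3 = expr1 or (not m1)
    idxs0 = int(bit0) | (int(bit1) << 1)
    idxs1 = int(bit2) | (int(bit3) << 1)
    return idxs0 | (idxs1 << 2)

assert meta_nibble(True, True, False, False) == 0b0100
assert meta_nibble(True, False, True, False) == 0b1000
assert meta_nibble(False, True, True, False) == 0b1001
assert meta_nibble(True, False, False, True) == 0b1100
assert meta_nibble(False, False, False, False) == 0b1110  # degenerate all-zero case
```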
- - expr0 = m0 & m1 - expr1 = ~m0 & m1 - expr2 = ~m0 & ~m1 - bit0 = expr1 - bit1 = expr2 - bit2 = expr0 | expr2 | m3 - bit3 = expr1 | ~m1 - idxs0 = bit0 | (bit1.to(torch.int64) << 1) - idxs1 = bit2 | (bit3.to(torch.int64) << 1) - - if dense.dtype != torch.float: - sparse0 = dense_4.gather( - -1, idxs0.unsqueeze(-1) - ) # type: ignore[possibly-undefined] - sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) - sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) - else: - sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view( - m, k // 2 - ) # type: ignore[possibly-undefined] - - meta_4 = idxs0 | (idxs1 << 2) - meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) - - if quadbits_per_meta_elem == 4: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - ) - elif quadbits_per_meta_elem == 8: - meta = ( - meta_n[:, :, 0] - | (meta_n[:, :, 1] << 4) - | (meta_n[:, :, 2] << 8) - | (meta_n[:, :, 3] << 12) - | (meta_n[:, :, 4] << 16) - | (meta_n[:, :, 5] << 20) - | (meta_n[:, :, 6] << 24) - | (meta_n[:, :, 7] << 28) - ) - - # Reorder meta tensor elements. - meta_reordered = meta.new_empty( - (m * meta_ncols,) - ) # type: ignore[possibly-undefined] - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) - - return (sparse, meta_reordered.view(m, meta_ncols)) - - -# This function performs reverse of the function above - it -# reconstructs dense matrix from a pair of "compressed" matrix, given -# in the layout used by CUTLASS backend, and accompanying metadata -# matrix. -def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): - if sparse.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 - ) - - m, k = sparse.shape - device = sparse.device - - if meta_reordered.dim() != 2: - raise RuntimeError( - f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 - ) - if meta_reordered.device != device: - raise RuntimeError( - f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 - ) - - meta_dtype = meta_reordered.dtype - if meta_dtype not in (torch.int16, torch.int32): - raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") - quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 - - ksparse = 4 if sparse.dtype != torch.float else 2 - - meta_nrows, meta_ncols = meta_reordered.shape - if meta_nrows != m: - raise RuntimeError( - f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 - ) - if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: - raise RuntimeError( - f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 - "expected according to the number of columns of meta matrix" - ) - - # Undo meta tensor elements reordering. - meta_offsets = _calculate_meta_reordering_scatter_offsets( - m, meta_ncols, meta_dtype, device - ) - meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols) - - # Unpack sparse tensor back to original dense tensor, using - # information provided by meta tensor. 
-    # Note that the torch.float datatype is handled pretty much the same as
-    # torch.half/torch.bfloat16: the metadata for a pair of torch.float
-    # values is encoded as if the underlying 8 bytes contained four
-    # torch.half/torch.bfloat16 values, where either the first two or the
-    # last two are zeros.
-    meta_2 = torch.empty(
-        (m, meta_ncols, 2 * quadbits_per_meta_elem),
-        dtype=meta_dtype,
-        device=device,
-    )
-    if quadbits_per_meta_elem == 4:
-        meta_2[:, :, 0] = meta & 0b11
-        meta_2[:, :, 1] = (meta >> 2) & 0b11
-        meta_2[:, :, 2] = (meta >> 4) & 0b11
-        meta_2[:, :, 3] = (meta >> 6) & 0b11
-        meta_2[:, :, 4] = (meta >> 8) & 0b11
-        meta_2[:, :, 5] = (meta >> 10) & 0b11
-        meta_2[:, :, 6] = (meta >> 12) & 0b11
-        meta_2[:, :, 7] = (meta >> 14) & 0b11
-    elif quadbits_per_meta_elem == 8:
-        meta_2[:, :, 0] = meta & 0b11
-        meta_2[:, :, 1] = (meta >> 2) & 0b11
-        meta_2[:, :, 2] = (meta >> 4) & 0b11
-        meta_2[:, :, 3] = (meta >> 6) & 0b11
-        meta_2[:, :, 4] = (meta >> 8) & 0b11
-        meta_2[:, :, 5] = (meta >> 10) & 0b11
-        meta_2[:, :, 6] = (meta >> 12) & 0b11
-        meta_2[:, :, 7] = (meta >> 14) & 0b11
-        meta_2[:, :, 8] = (meta >> 16) & 0b11
-        meta_2[:, :, 9] = (meta >> 18) & 0b11
-        meta_2[:, :, 10] = (meta >> 20) & 0b11
-        meta_2[:, :, 11] = (meta >> 22) & 0b11
-        meta_2[:, :, 12] = (meta >> 24) & 0b11
-        meta_2[:, :, 13] = (meta >> 26) & 0b11
-        meta_2[:, :, 14] = (meta >> 28) & 0b11
-        meta_2[:, :, 15] = (meta >> 30) & 0b11
-
-    dense_offsets = meta_2.view(-1) + (
-        torch.arange(0, 2 * m * k // ksparse, device=device) * 4
-    ).view(-1, 1).repeat(1, 2).view(-1)
-
-    dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
-    if sparse.dtype != torch.float:
-        # dense.scatter_(0, dense_offsets, sparse.view(-1))
-        dense.scatter_(0, dense_offsets, sparse.reshape(-1))
-    else:
-        dense.view(torch.half).scatter_(
-            0, dense_offsets, sparse.view(torch.half).view(-1)
-        )
-
-    return dense.view(m, 2 * k)
-
-
-def mask_creator(tensor):
-    """
-    Create an N:M sparsity mask for the given tensor.
-    The mask is built from the N:M ratio: within every group of M weights,
-    the M - N weights with the smallest magnitude are pruned (zeroed in the
-    mask), so N weights per group are kept. The returned mask has the same
-    shape as the given tensor.
- - :param N: The number of weights in a group to keep - :param M: The size of a weight group - """ - N = 2 - M = 4 - - mask = None - # for i, tensor in enumerate(tensors): - if tensor.numel() % M != 0: - raise ValueError( - f"Tensor of size {tensor.shape} can't be evenly divided into " f"{M} groups" - ) - - num_groups = tensor.numel() // M - - # N:M sparsity for linear layers - tensor_temp = tensor.detach().abs().reshape(num_groups, M) - index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)] - - w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) - mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) - - return mask - - -def inject_24(w, size_k, size_n): - assert w.shape == (size_k, size_n) - - mask = mask_creator(w.t()).t().cuda().bool() - - return (mask * w).contiguous(), mask.contiguous() - - -def check_24(w, num_rows_to_sample=50, _verbose=False): - BLOCK_SIZE = 4 - MAX_NON_ZEROS = 2 - - w = w.t().contiguous() - - print("check_24: w.shape = {}".format(w.shape)) - - num_rows, num_cols = w.shape - sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) - if _verbose: - print(f"Sampled row idxs = {sampled_row_idxs}") - - total_segments = 0 - non_24_segments = 0 - for i in sampled_row_idxs: - for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): - total_segments += 1 - block = w[i, j : j + BLOCK_SIZE] - num_nonzero = torch.count_nonzero(block) - if num_nonzero > MAX_NON_ZEROS: - print("i = {} j = {} block = {}".format(i, j, block)) - non_24_segments += 1 - - print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") - - -def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): - assert q_24.shape == (size_k, size_n) - - # Remove bias to normalize over 0 - q_24_no_zp = q_24 - wtype.bias - - # Compress - q_24_no_zp = q_24_no_zp.t().contiguous() - q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp) - q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() - - # Restore bias - q_24_comp = q_24_no_zp_comp + wtype.bias - - # Resize meta to its actual shape (without moving any data) - meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) - - return q_24_comp, meta - - -def get_scale_perms_24(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) - scale_perm_single: List[int] = [] - for i in range(8): - scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) - return scale_perm, scale_perm_single - - -def get_weight_perm_24(num_bits: int): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - col_o = col // 2 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block) - for j in range(4): - perm_list.extend([p + 1 * j for p in perm1]) - perm = numpy.array(perm_list) - - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_permute_scales_24( - s: torch.Tensor, size_k: int, size_n: int, group_size: int -) -> torch.Tensor: - - scale_perm, scale_perm_single = get_scale_perms_24() - if group_size < size_k and group_size != -1: - s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] - else: - s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] - s = s.reshape((-1, size_n)).contiguous() - - return s - - -def marlin_24_quantize( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - # Inject 2:4 sparsity - w_24, mask_24 = inject_24(w, size_k, size_n) - - # Quantize - w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( - w_24, quant_type, group_size, act_order=False - ) - - # Compress quantized weight - q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type) - size_k_comp = size_k // 2 - - # Reformat to marlin - weight_perm = get_weight_perm_24(quant_type.size_bits) - marlin_24_q_w_comp = marlin_weights( - q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm - ) - marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) - - # Create result - res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py deleted file mode 100644 index cb58eb945836393c58c53f5c6d702d53861c33f9..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/marlin_utils_test_qqq.py +++ /dev/null @@ -1,125 +0,0 @@ -from typing import List - -import numpy -import torch - -from .marlin_utils_test import marlin_permute_weights -from .quant_utils import get_pack_factor, qqq_quantize_weights - - -def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): - # Permute - q_w = marlin_permute_weights(q_w, size_k, size_n, perm) - - # Pack - pack_factor = get_pack_factor(num_bits) - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), - dtype=numpy.uint32) - if group_size == size_k: - for i in range(pack_factor): - q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i - else: - for i in range(pack_factor): - q_packed |= q_w[:, i::pack_factor] << num_bits * i - - q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) - - return q_packed - - -def get_qqq_scale_perms(): - scale_perm: List[int] = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single: List[int] = [] - for i in range(4): - scale_perm_single.extend( - [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return scale_perm, scale_perm_single - - -# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 -def get_qqq_weight_perm(num_bits: int, quant_type: str): - perm_list: List[int] = [] - for i in range(32): - perm1: List[int] = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 4 * (i % 4), - 4 * (i % 4) + 1, - 4 * (i % 4) + 2, - 4 * (i % 4) + 3, - ]: - perm1.append(16 * row + col + 8 * block) - for j in range(4): - perm_list.extend([p + 256 * j for p in perm1]) - - perm = numpy.array(perm_list) - - assert quant_type in ["per-channel", - "per-group"], "not supported quantization type" - if num_bits == 4: - if quant_type == "per-channel": - interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) - else: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - else: - raise Exception("num_bits must be 4, got {}".format(num_bits)) - - perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() - perm = torch.from_numpy(perm) - return perm - - -def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): - scale_perm, scale_perm_single = get_qqq_scale_perms() - if group_size < size_k and group_size != -1: - s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_group = s_group.reshape((-1, size_n)).contiguous() - else: - s_channel = s_channel.reshape( - (-1, len(scale_perm_single)))[:, scale_perm_single] - s_channel = s_channel.reshape((-1, size_n)).contiguous() - - return s_group, s_channel - - -def marlin_qqq_quantize( - w: torch.Tensor, - num_bits: int, - group_size: int, -): - size_k, size_n = w.shape - - # Normalize group_size - if group_size == -1: - group_size = size_k - assert group_size <= size_k - quant_type = "per-channel" if group_size == size_k else "per-group" - - # Quantize - w_ref, q_w, s_group, s_channel = qqq_quantize_weights( - w, num_bits, group_size) - - # Reformat to marlin_qqq - weight_perm = get_qqq_weight_perm(num_bits, quant_type) - marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, - weight_perm, group_size) - marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( - s_group, s_channel, size_k, size_n, group_size) - - # Create result - res_list = [ - w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel - ] - for i in range(len(res_list)): - res_list[i] = res_list[i].to(w.device) - - return res_list diff --git a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/quant_utils.py b/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/quant_utils.py deleted file mode 100644 index d97e03913fa5980e0be73b160088c8e4f5f49a52..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-rocm63-x86_64-linux/quantization/utils/quant_utils.py +++ /dev/null @@ -1,470 +0,0 @@ -"""This file is used for /tests and /benchmarks""" - -from typing import List, Optional - -import numpy -import torch - -from quantization.scalar_type import ScalarType, scalar_types - -SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] -SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] - -MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] - -# Note: this is a hack. We should update each model to register the -# stacked params and get it from there instead in a future PR. 
-
-# fused_name: List[shard_name]
-FUSED_LAYER_NAME_MAPPING = {
-    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-    "gate_up_proj": ["gate_proj", "up_proj"],
-}
-
-
-def pack_quantized_values_into_int32(
-    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
-):
-    # move the dim to pack to the end
-    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
-    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
-    w_q_perm = w_q.permute(perm)
-
-    pack_factor = 32 // wtype.size_bits
-    mask = (1 << wtype.size_bits) - 1
-
-    new_shape_perm = list(w_q_perm.shape)
-    assert w_q_perm.shape[-1] % pack_factor == 0
-    new_shape_perm[-1] //= pack_factor
-
-    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
-    for i in range(pack_factor):
-        res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
-
-    return res.permute(inv_perm)
-
-
-def unpack_quantized_values_into_int32(
-    w_q: torch.Tensor, wtype: ScalarType, packed_dim: int = 0
-):
-    # move the dim to unpack to the end
-    perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
-    inv_perm = tuple(perm.index(i) for i in range(len(perm)))
-    w_q_perm = w_q.permute(perm)
-
-    pack_factor = 32 // wtype.size_bits
-    mask = (1 << wtype.size_bits) - 1
-
-    new_shape_perm = list(w_q_perm.shape)
-    new_shape_perm[-1] *= pack_factor
-
-    res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
-    for i in range(pack_factor):
-        res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
-
-    return res.permute(inv_perm)
-
-
-def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
-    # prefix: model.layers.0.self_attn.q_proj
-    # proj_name: q_proj
-    proj_name = prefix.split(".")[-1]
-    if proj_name in FUSED_LAYER_NAME_MAPPING:
-        shard_prefixes = [
-            prefix.replace(proj_name, shard_proj_name)
-            for shard_proj_name in FUSED_LAYER_NAME_MAPPING[proj_name]
-        ]
-
-        is_skipped = None
-        for shard_prefix in shard_prefixes:
-            is_shard_skipped = shard_prefix in ignored_layers
-
-            if is_skipped is None:
-                is_skipped = is_shard_skipped
-            elif is_shard_skipped != is_skipped:
-                raise ValueError(
-                    f"Detected some but not all shards of {prefix} "
-                    "are quantized. All shards of a fused layer "
-                    "are required to have the same precision."
- ) - else: - is_skipped = prefix in ignored_layers - - assert is_skipped is not None - return is_skipped - - -def get_pack_factor(num_bits): - assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" - return 32 // num_bits - - -def permute_rows( - q_w: torch.Tensor, - w_ref: torch.Tensor, - group_size: int, - test_perm: Optional[torch.Tensor] = None, -): - assert q_w.shape == w_ref.shape - - orig_device = q_w.device - k_size, _ = q_w.shape - - g_idx = torch.zeros((k_size,), dtype=torch.int32) - for i in range(k_size): - g_idx[i] = i // group_size - - # Simulate act_order by doing a random permutation on K - rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) - - g_idx = g_idx[rand_perm].contiguous() - q_w = q_w[rand_perm, :].contiguous() - w_ref = w_ref[rand_perm, :].contiguous() - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - rand_perm.to(device=orig_device), - ) - - -def quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: Optional[int], - zero_points: bool = False, - ref_zero_points_after_scales: bool = False, -): - assert ( - quant_type.is_integer() - ), "Floating point quantization may work but has not been tested" - assert not zero_points or group_size is not None, ( - "to have group zero points, group_size must be provided " - "(-1 group_size is channelwise)" - ) - - orig_device = w.device - orig_type = w.dtype - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - - if group_size == -1: - group_size = size_k - - # Reshape to [groupsize, -1] - if group_size is not None and group_size < size_k: - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - # Compute scale for each group - max_val = torch.max(w, 0, keepdim=True).values - min_val = torch.min(w, 0, keepdim=True).values - - max_q_val = quant_type.max() - min_q_val = quant_type.min() - - w_s = torch.Tensor([1.0]).to(w.device) # unscaled case - maybe_w_zp = None - if group_size is not None: - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = ( - torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int() - ) - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)), - ) - - # Quantize - w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) - w_q = torch.clamp(w_q, min_q_val, max_q_val) - - # Compute ref (dequantized) - # For some kernels (namely Machete) the zero-points are applied after the - # scales are applied, for this case computing the reference in similar way - # allows us to use tighter error tolerances in our unit tests. 
- if ref_zero_points_after_scales and maybe_w_zp is not None: - w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s - else: - w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s - - if quant_type.has_bias(): - w_q += quant_type.bias - - # Restore original shapes - if group_size is not None and group_size < size_k: - - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - w_q = reshape_w(w_q) - w_ref = reshape_w(w_ref) - w_s = w_s.reshape((-1, size_n)).contiguous() - - if maybe_w_zp is not None: - maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() - maybe_w_zp = maybe_w_zp.to(device=orig_device) - - return ( - w_ref.to(device=orig_device), - w_q.to(device=orig_device), - w_s if group_size is not None else None, - maybe_w_zp, - ) - - -def gptq_quantize_weights( - w: torch.Tensor, - quant_type: ScalarType, - group_size: int, - act_order: bool, - test_perm: Optional[torch.Tensor] = None, -): - size_k, _ = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - quant_type in SUPPORTED_GPTQ_QUANT_TYPES - ), f"Unsupported gptq type = {quant_type}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) - - # Apply act_order - g_idx = torch.empty(0, dtype=torch.int, device=w.device) - rand_perm = torch.empty(0, dtype=torch.int, device=w.device) - if act_order: - assert ( - group_size < size_k - ), "For act_order, groupsize = {} must be less than size_k = {}".format( - group_size, size_k - ) - - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm) - - return w_ref, w_q, w_s, g_idx, rand_perm - - -# QQQ employs different quant schemes for per-group and -# per-channel quantization. 
-def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): - orig_device = w.device - size_k, size_n = w.shape - - assert w.is_floating_point(), "w must be float" - assert ( - num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS - ), f"Unsupported num_bits = {num_bits}" - assert group_size in SUPPORTED_GROUP_SIZES + [ - size_k - ], f"Unsupported groupsize = {group_size}" - - if group_size == -1: - group_size = size_k - assert group_size <= size_k - - if group_size < size_k: - # Reshape to [groupsize, -1] - w = w.reshape((-1, group_size, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((group_size, -1)) - - max_q_val = 2**num_bits - 1 - half_q_val = (max_q_val + 1) // 2 - - # Compute scale for each group - s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_group *= 2 / max_q_val # 2 => symmetric - - # Quantize - q_w = torch.round(w / s_group).int() - q_w += half_q_val - q_w = torch.clamp(q_w, 0, max_q_val) - # Compute ref (dequantized) - w_ref = (q_w - half_q_val).half() * s_group - - # Restore original shapes - def reshape_w(w): - w = w.reshape((group_size, -1, size_n)) - w = w.permute(1, 0, 2) - w = w.reshape((size_k, size_n)).contiguous() - return w - - q_w = reshape_w(q_w) - w_ref = reshape_w(w_ref) - - # Compute int8 quantization scale for each channel - s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] - s_channel /= 127.0 - t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) - w_ref = t_int8.half() * s_channel - s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) - - # Fuse scales - s_group = (s_group.reshape(-1, size_n).contiguous() / s_channel).to( - dtype=torch.half - ) - else: - max_q_val = 2 ** (num_bits - 1) - 1 - - # Compute scale for each channel - s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] - s_channel /= max_q_val - - # Quantize - q_w = torch.round(w / s_channel).int() - q_w = torch.clamp(q_w, -max_q_val, max_q_val) - # Compute ref (dequantized) - w_ref = q_w.half() * s_channel - - s_group = torch.tensor([], dtype=torch.half) - # div 2 ** (8 - self.bits)) to offset right shift in unpacking - s_channel /= 2 ** (8 - num_bits) - s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) - - return ( - w_ref.to(device=orig_device), - q_w.to(device=orig_device), - s_group.to(device=orig_device), - s_channel.to(device=orig_device), - ) - - -def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): - orig_device = q_w.device - - sort_indices = torch.argsort(g_idx).to(dtype=torch.int32) # Sort based on g_idx - - g_idx = g_idx[sort_indices].contiguous() - q_w = q_w[sort_indices, :].contiguous() - - return ( - q_w.to(device=orig_device), - g_idx.to(device=orig_device), - sort_indices.to(device=orig_device), - ) - - -def pack_rows( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_k % pack_factor == 0 - - orig_device = q_w.device - - q_w = q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[i::pack_factor, :] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - return q_res - - -def pack_cols( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - - orig_device = q_w.device - - q_w = 
q_w.cpu().numpy().astype(numpy.uint32) - - q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) - - for i in range(pack_factor): - q_res |= q_w[:, i::pack_factor] << num_bits * i - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def unpack_cols( - packed_q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - pack_factor = get_pack_factor(num_bits) - assert size_n % pack_factor == 0 - assert packed_q_w.shape == ( - size_k, - size_n // pack_factor, - ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( - packed_q_w.shape, size_k, size_n, pack_factor - ) - - orig_device = packed_q_w.device - - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) - q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) - - mask = (1 << num_bits) - 1 - for i in range(pack_factor): - vals = packed_q_w_cpu & mask - packed_q_w_cpu >>= num_bits - q_res[:, i::pack_factor] = vals - - q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) - q_res = q_res.contiguous() - - return q_res - - -def gptq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - return pack_rows(q_w, num_bits, size_k, size_n) - - -def awq_pack( - q_w: torch.Tensor, - num_bits: int, - size_k: int, - size_n: int, -): - assert q_w.shape == (size_k, size_n) - - # Interleave column dim (for the dequantize code) and pack it to int32 - if num_bits == 4: - interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) - elif num_bits == 8: - interleave = numpy.array([0, 2, 1, 3]) - else: - raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) - - q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() - q_w = q_w.reshape((-1, size_n)).contiguous() - - return pack_cols(q_w, num_bits, size_k, size_n)
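The packing helpers above are plain torch/numpy code, so they can be exercised directly. Below is a hedged usage sketch (an editorial addition, not part of the diff) that round-trips random 4-bit values through pack_cols and unpack_cols. The import path assumes the package is importable as `quantization`, matching the build trees this diff removes; adjust it to wherever quant_utils.py lives in your checkout.

# Editorial usage sketch: pack 4-bit values column-wise into int32 and unpack
# them again, verifying the round trip is lossless.
import torch

from quantization.utils.quant_utils import get_pack_factor, pack_cols, unpack_cols

num_bits = 4
size_k, size_n = 16, 32                 # size_n must be divisible by the pack factor
pack_factor = get_pack_factor(num_bits)  # 32 // 4 == 8

# Random already-quantized values in [0, 2**num_bits).
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)        # shape (size_k, size_n // 8)
unpacked = unpack_cols(packed, num_bits, size_k, size_n)  # shape (size_k, size_n)

assert packed.shape == (size_k, size_n // pack_factor)
assert torch.equal(unpacked, q_w)  # lossless for values that fit in num_bits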