import enum
from enum import Enum
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
    ) -> None:
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.pack_factor = 32 // self.weight_bits
        # The exllama v1 kernel only supports 4-bit weights.
        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"GPTQ, but got {self.weight_bits} bits.")

    def __repr__(self) -> str:
        return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"desc_act={self.desc_act})")

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: confirm the minimum compute capability the GPTQ kernels need.
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        return cls(weight_bits, group_size, desc_act)

    def get_linear_method(self) -> "GPTQLinearMethod":
        return GPTQLinearMethod(self)

    def get_scaled_act_names(self) -> List[str]:
        return []
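

# Illustrative sketch (not part of the module's API surface): how an
# AutoGPTQ-style quantize_config.json maps onto GPTQConfig via from_config().
# Only the "bits", "group_size", and "desc_act" keys are read; the values
# below are made-up examples.
#
#     config = GPTQConfig.from_config({
#         "bits": 4,
#         "group_size": 128,
#         "desc_act": False,
#     })
#     assert config.pack_factor == 8  # 8 x 4-bit weights per packed int32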


class ExllamaState(Enum):
    UNUSED = enum.auto()  # exllama unusable (act-order + row-parallel)
    UNINITIALIZED = enum.auto()  # weights not yet shuffled for exllama
    READY = enum.auto()  # weights shuffled, exllama kernel in use


class GPTQLinearMethod(LinearMethodBase):
    """Linear method for GPTQ.

    Args:
        quant_config: The GPTQ quantization config.
    """

    def __init__(self, quant_config: GPTQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        del output_size  # Unused.
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size
        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use exllama for the row
            # parallel layer.
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # We need to partition qzeros and scales for the exllama
                # kernel.
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = Parameter(
            torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight, {
                "input_dim": 0,
                "output_dim": 1,
                "packed_dim": 0,
                "pack_factor": self.quant_config.pack_factor,
            })
        g_idx = Parameter(
            torch.tensor(
                [
                    i // self.quant_config.group_size
                    for i in range(input_size_per_partition)
                ],
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        # Ignore warning from fused linear layers such as QKVParallelLinear.
        set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
        qzeros = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                device="cuda",
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qzeros, {
                "input_dim": scale_and_zero_input_dim,
                "output_dim": 1,
                "packed_dim": 1,
                "pack_factor": self.quant_config.pack_factor,
            })
        scales = Parameter(
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                device="cuda",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 1,
        })
        return {
            "qweight": qweight,
            "g_idx": g_idx,
            "qzeros": qzeros,
            "scales": scales,
            "exllama_state": exllama_state,
        }
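
    # Shape sketch (illustrative numbers, assuming 4-bit GPTQ so
    # pack_factor == 8): with input_size_per_partition = 4096,
    # output_size_per_partition = 4096 and group_size = 128,
    # create_weights() above allocates
    #   qweight: (4096 // 8, 4096) = (512, 4096)  int32, 8 weights per word
    #   g_idx:   (4096,)                          int32, group id per input
    #   qzeros:  (32, 4096 // 8) = (32, 512)      int32, packed zero points
    #   scales:  (32, 4096)                       params_dtype
    # where 32 = 4096 // 128 groups (scale_and_zero_size).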

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        qweight = weights["qweight"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # The exllama kernel needs the weights shuffled after they are
        # loaded, so the shuffle is done here on the first forward pass.
        if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
                    torch.int)
            else:
                weights["g_idx"] = torch.empty((1, 1), device="meta")
            weights["exllama_state"] = ExllamaState.READY
            ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
        output = ops.gptq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY)
        if bias is not None:
            output = output + bias
        return output.reshape(out_shape)
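

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: it assumes
    # a CUDA device and the vllm._C extension are available. The weights
    # created here are uninitialized placeholders (a real checkpoint would be
    # loaded over them), so the output values are meaningless; the point is
    # only to exercise the create_weights / apply_weights flow end to end.
    config = GPTQConfig(weight_bits=4, group_size=128, desc_act=False)
    method = config.get_linear_method()
    weights = method.create_weights(
        input_size_per_partition=4096,
        output_size_per_partition=4096,
        input_size=4096,
        output_size=4096,
        params_dtype=torch.half,
    )
    x = torch.zeros(1, 4096, dtype=torch.half, device="cuda")
    out = method.apply_weights(weights, x)
    print(out.shape)  # Expected: torch.Size([1, 4096])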