import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map


class BaseAWQForCausalLM(nn.Module):
    """Base wrapper around a Hugging Face causal LM that adds AWQ quantization,
    saving of quantized checkpoints, and loading them back for inference."""

    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: Dict = quant_config

    def to(self, device: str):
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                       calib_data: Union[str, List[str]] = "pileval",
                       split="train", text_column="text"):
        # Avoid a shared mutable default argument; fall back to an empty config.
        quant_config = {} if quant_config is None else quant_config
        self.quant_config = quant_config
        quant_config["version"] = quant_config.get("version", "GEMM")

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        # No-op here; model-specific subclasses override this to fuse modules.
        pass

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Placeholder module with an empty state dict, used so that
        # save_pretrained() writes config files without any weights.
        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Save config and generation files with an empty state dict,
        # then remove the empty weights file it produced.
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        os.remove(f'{save_dir}/pytorch_model.bin')

        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # Shard the real (quantized) state dict.
        shards, index = shard_checkpoint(
            self.model.state_dict(),
            max_shard_size=shard_size,
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors requires contiguous tensors that own their memory.
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # Write the shard index if the checkpoint was split.
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Persist the quantization config alongside the weights.
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))

    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                              trust_remote_code=True, safetensors=False, device_map=None,
                              **model_init_kwargs):
        # Resolve the checkpoint path and load the model and quantization configs.
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        # If no device map was given, infer one from an empty-weight skeleton.
        if device_map is None:
            with init_empty_weights():
                model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

            device_map = infer_auto_device_map(
                model,
                no_split_module_classes=[self.layer_type],
                dtype=torch_dtype
            )
            del model

        # Load the full-precision model with the provided or inferred device map.
        model = AutoModelForCausalLM.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs
        )

        model.eval()

        return self(model, model_type, is_quantized=False, quant_config=quant_config)

    @classmethod
    def from_quantized(self, model_path, model_type, model_filename='',
                             max_new_tokens=None, torch_dtype=torch.float16,
                             trust_remote_code=True, safetensors=False, is_quantized=True,
                             fuse_layers=False, version='GEMM',
                             max_memory=None, offload_folder=None):
        # Resolve the checkpoint path and load the model and quantization configs.
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, model_filename, safetensors, version,
            trust_remote_code, max_new_tokens=max_new_tokens
        )

        # Build the model skeleton without allocating real weights.
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Swap nn.Linear layers for their quantized WQLinear counterparts.
        self._load_quantized_modules(self, model, quant_config, quant_config["version"])

        model.tie_weights()

        # Decide where each block should live (GPUs, CPU, or disk offload).
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[self.layer_type],
            max_memory=max_memory,
            dtype=torch_dtype
        )

        # Load the quantized weights into the skeleton according to the device map.
        load_checkpoint_in_model(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            offload_folder=offload_folder,
            dtype=torch_dtype
        )

        # Optionally fuse layers for faster inference (model-specific).
        if fuse_layers:
            self.fuse_layers(model, quant_config)

        # Dispatch the model across devices.
        from accelerate import dispatch_model
        model = dispatch_model(
            model,
            device_map=device_map,
            offload_dir=offload_folder
        )

        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_config(self, model_path, model_filename, safetensors=False,
                           version="GEMM", trust_remote_code=True, max_new_tokens=4096):
        # Download the checkpoint from the Hub if a local directory was not given,
        # skipping weight formats we do not need.
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")

            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # Load the quantization config if present, otherwise fall back to defaults.
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())

            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}

        # Load the model config and set the maximum generation length.
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
        # Only zero-point quantization is currently supported.
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Get the transformer blocks of the model.
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]

            # Get every linear layer in the block.
            named_linears = get_named_linears(layer)

            # Replace the activation function with a scaled version if needed.
            self._scale_activations(self, layer)

            # Replace each nn.Linear with its quantized counterpart.
            for name, module in named_linears.items():
                if version == 'GEMM':
                    q_linear_module = WQLinear_GEMM
                elif version == 'GEMV':
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict['is_scalable']:
            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
                param = next(layer.parameters())

                # Build a scale tensor of ones on the same device/dtype as the layer.
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # Wrap the activation so its output can be rescaled, and swap it in.
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
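
# A minimal usage sketch (not part of the library API): `LlamaAWQForCausalLM` below is a
# hypothetical subclass that would define `layer_type`, `max_new_tokens_key`,
# `get_model_layers`, `get_act_for_scaling`, and `fuse_layers`; the model name and
# output directory are placeholders.
#
#   from transformers import AutoTokenizer
#
#   model_path = "some-org/some-causal-lm"  # placeholder checkpoint
#   tokenizer = AutoTokenizer.from_pretrained(model_path)
#
#   # Quantize a full-precision model and save the AWQ checkpoint.
#   model = LlamaAWQForCausalLM.from_pretrained(model_path, "llama")
#   model.quantize(tokenizer, quant_config={"zero_point": True, "q_group_size": 128,
#                                           "w_bit": 4, "version": "GEMM"})
#   model.save_quantized("model-awq", safetensors=True)
#
#   # Later: reload the quantized checkpoint for inference.
#   model = LlamaAWQForCausalLM.from_quantized("model-awq", "llama", fuse_layers=True)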