# Snapshot metadata (from the file viewer): file size 10,911 bytes, commit 72268ee.
import os
import gc
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union, Dict
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
class BaseAWQForCausalLM(nn.Module):
    """Wrap a Hugging Face causal LM with AWQ quantize / save / load helpers.

    This base class references several attributes that it does not define
    itself and that concrete subclasses are expected to supply:
    ``layer_type``, ``get_model_layers``, ``get_act_for_scaling`` and,
    optionally, ``max_new_tokens_key``.
    """

    def __init__(self, model, model_type, is_quantized, quant_config):
        """Store the wrapped model together with its quantization metadata.

        Args:
            model: The underlying model that all calls are delegated to.
            model_type: Identifier string of the model architecture.
            is_quantized: Whether ``model`` already holds quantized weights.
            quant_config: Quantization settings dict (``w_bit``,
                ``q_group_size``, ``zero_point``, ``version``).
        """
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.quant_config: Dict = quant_config

    def to(self, device: str):
        """Move the wrapped model to ``device`` and return the wrapped model."""
        return self.model.to(device)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self.model(*args, **kwargs)

    def generate(self, *args, **kwargs):
        """Run the wrapped model's ``generate`` under ``torch.inference_mode``."""
        with torch.inference_mode():
            return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config=None,
                 calib_data: Union[str, List[str]] = "pileval",
                 split="train", text_column="text"):
        """Quantize the wrapped model with ``AwqQuantizer``.

        Args:
            tokenizer: Tokenizer forwarded to the quantizer.
            quant_config: Dict with at least ``w_bit`` and ``q_group_size``.
                A missing ``version`` key is filled in with ``"GEMM"`` (the
                dict passed by the caller is updated in place and stored on
                the instance). When omitted, a copy of the configuration
                already stored on the instance is used.
            calib_data: Calibration data specification forwarded to the
                quantizer.
            split: Forwarded to the quantizer.
            text_column: Forwarded to the quantizer.

        Raises:
            KeyError: If ``w_bit`` or ``q_group_size`` is missing.
        """
        if quant_config is None:
            # A literal ``{}`` default would be shared (and mutated) across
            # calls; start from the instance's own settings instead.
            quant_config = dict(self.quant_config or {})

        quant_config.setdefault("version", "GEMM")
        self.quant_config = quant_config

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        """Hook for layer fusion; the base implementation does nothing."""
        pass

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        """Write model files, sharded weights and ``quant_config.json``.

        Args:
            save_dir: Target directory; a single trailing ``/`` is dropped.
            safetensors: Save shards with ``safetensors`` instead of
                ``torch.save``.
            shard_size: Maximum shard size passed to ``shard_checkpoint``.
        """
        save_dir = save_dir[:-1] if save_dir.endswith('/') else save_dir

        class EmptyModule(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x

        # Let save_pretrained emit the non-weight model files, using an empty
        # state dict so the real weights are written by the shard loop below.
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Remove the placeholder weights file. Its name depends on which
        # serialization save_pretrained used, so drop whichever one exists
        # instead of assuming 'pytorch_model.bin' is present.
        for placeholder in ('model.safetensors', 'pytorch_model.bin'):
            placeholder_path = os.path.join(save_dir, placeholder)
            if os.path.exists(placeholder_path):
                os.remove(placeholder_path)

        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # Shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(),
            max_shard_size=shard_size,
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            shard_path = os.path.join(save_dir, shard_file)
            if safetensors:
                # Clone into contiguous memory before handing to safetensors.
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, shard_path, metadata={"format": "pt"})
            else:
                torch.save(shard, shard_path)

        # Save shard index (only produced when there is more than one shard)
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Save quantization config
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))

    @classmethod
    def from_pretrained(cls, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                        trust_remote_code=True, safetensors=False, device_map=None,
                        **model_init_kwargs):
        """Load an unquantized model and wrap it.

        Args:
            model_path: Local directory or hub repository id.
            model_type: Identifier string stored on the wrapper.
            torch_dtype: dtype passed to ``AutoModelForCausalLM.from_pretrained``.
            trust_remote_code: Forwarded to the config and model loaders.
            safetensors: Forwarded as ``use_safetensors``; also selects which
                weight files are skipped when downloading.
            device_map: Optional device map forwarded to
                ``AutoModelForCausalLM.from_pretrained``. ``None`` keeps the
                loader's default placement.
            **model_init_kwargs: Extra keyword arguments for the model loader.

        Returns:
            A wrapper instance with ``is_quantized=False``.
        """
        # _load_config is a plain function taking the class explicitly.
        model_weights_path, config, quant_config = cls._load_config(
            cls, model_path, '', safetensors, trust_remote_code=trust_remote_code
        )

        # An explicitly supplied device_map is honoured. No map is inferred
        # here: the previously inferred one was never handed to the loader.
        model = AutoModelForCausalLM.from_pretrained(
            model_weights_path,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs
        )
        model.eval()

        return cls(model, model_type, is_quantized=False, quant_config=quant_config)

    @classmethod
    def from_quantized(cls, model_path, model_type, model_filename='',
                       max_new_tokens=None, torch_dtype=torch.float16,
                       trust_remote_code=True, safetensors=False, is_quantized=True,
                       fuse_layers=False, version='GEMM',
                       max_memory=None, offload_folder=None):
        """Load previously quantized weights and wrap the resulting model.

        Args:
            model_path: Local directory or hub repository id.
            model_type: Identifier string stored on the wrapper.
            model_filename: Optional file name appended to ``model_path`` to
                form the checkpoint path.
            max_new_tokens: Value for ``config.max_new_tokens``; see
                ``_load_config`` for how ``None`` is resolved.
            torch_dtype: dtype used for the empty model, the device map and
                checkpoint loading.
            trust_remote_code: Forwarded to the config and model builders.
            safetensors: Selects which weight files are skipped on download.
            is_quantized: Stored on the returned wrapper.
            fuse_layers: Call ``cls.fuse_layers`` before dispatching.
            version: Kernel version used when the quant config has none.
            max_memory: Forwarded to ``infer_auto_device_map``.
            offload_folder: Forwarded to checkpoint loading and dispatch.
        """
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = cls._load_config(
            cls, model_path, model_filename, safetensors, version,
            trust_remote_code, max_new_tokens=max_new_tokens
        )

        # [STEP 3] Build the model skeleton without allocating weights
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code
            )

        # Prepare WQLinear layers, replace nn.Linear
        cls._load_quantized_modules(cls, model, quant_config, quant_config["version"])
        model.tie_weights()

        # Get device map
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=[cls.layer_type],
            max_memory=max_memory,
            dtype=torch_dtype
        )

        # Load checkpoint
        load_checkpoint_in_model(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
            offload_folder=offload_folder,
            dtype=torch_dtype
        )

        if fuse_layers:
            cls.fuse_layers(model, quant_config)

        # Dispatch to devices (with optional offloading)
        from accelerate import dispatch_model
        model = dispatch_model(
            model,
            device_map=device_map,
            offload_dir=offload_folder
        )

        return cls(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    # NOTE: the two loaders below are plain functions that the classmethods
    # call as ``cls._load_x(cls, ...)``; ``self`` is therefore the class.
    def _load_config(self, model_path, model_filename, safetensors=False,
                     version="GEMM", trust_remote_code=True, max_new_tokens=4096):
        """Resolve the weights path and load the model and quant configs.

        Returns:
            Tuple ``(model_weights_path, config, quant_config)``.
        """
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*"]
            if safetensors:
                ignore_patterns.extend(["*.pt*", "*.bin*"])
            else:
                ignore_patterns.append("*.safetensors*")
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != '':
            model_weights_path = model_path + f'/{model_filename}'
        else:
            model_weights_path = model_path

        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())
            quant_config.setdefault("version", version)
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}

        # Load model config and set max generation length
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
            # Take the limit from the config attribute named by the subclass.
            config.max_new_tokens = getattr(config, self.max_new_tokens_key)
        else:
            config.max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens

        return model_weights_path, config, quant_config

    def _load_quantized_modules(self, model, quant_config, version):
        """Replace each block's linear layers with quantized ``WQLinear`` modules.

        Raises:
            ValueError: If ``version`` is neither ``'GEMM'`` nor ``'GEMV'``.
        """
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."

        # Resolve the kernel class once; an unknown version previously left
        # the variable unbound and failed inside the loop.
        if version == 'GEMM':
            q_linear_module = WQLinear_GEMM
        elif version == 'GEMV':
            q_linear_module = WQLinear_GEMV
        else:
            raise ValueError(
                f"Unsupported AWQ version {version!r}; expected 'GEMM' or 'GEMV'."
            )

        # Get blocks of model
        layers = self.get_model_layers(model)

        for layer in tqdm(layers, desc="Replacing layers..."):
            # Get every linear layer in a block
            named_linears = get_named_linears(layer)

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

    @staticmethod
    def _scale_activations(self, layer):
        """Wrap the block's activation in ``ScaledActivation`` when applicable.

        ``self`` is passed explicitly by the caller (this is a staticmethod).
        Nothing happens when ``get_act_for_scaling`` reports the block as not
        scalable or the activation is already a ``ScaledActivation``.
        """
        scale_dict = self.get_act_for_scaling(layer)
        if not scale_dict['is_scalable']:
            return
        if isinstance(scale_dict['scale_layer'], ScaledActivation):
            return

        param = next(layer.parameters())

        # Placeholder scale of ones matching the block's dtype and device
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

        # Scale activation
        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
# End of listing.