File size: 2,266 Bytes
6dc0c9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from dataclasses import dataclass, field
import os
from os.path import isdir, isfile
from pathlib import Path
import sys
from typing import Optional

from transformers import AutoTokenizer
@dataclass
class GptqConfig:
    """Settings used when loading a GPTQ-quantized model.

    Every attribute is declared through ``dataclasses.field`` and carries a
    human-readable ``"help"`` string in its field metadata.
    """

    # Path to the local GPTQ checkpoint. ``None`` (the default) means no
    # checkpoint has been configured, hence the Optional annotation.
    ckpt: Optional[str] = field(
        default=None,
        metadata={
            "help": "Load quantized model. The path to the local GPTQ checkpoint."
        },
    )
    # Number of bits used for quantization.
    wbits: int = field(default=16, metadata={"help": "#bits to use for quantization"})
    # Quantization group size; -1 selects the full row.
    groupsize: int = field(
        default=-1,
        metadata={"help": "Groupsize to use for quantization; default uses full row."},
    )
    # Whether the activation-order GPTQ heuristic is applied.
    act_order: bool = field(
        default=True,
        metadata={"help": "Whether to apply the activation order GPTQ heuristic"},
    )
def load_gptq_quantized(model_name, gptq_config: GptqConfig):
    """Load a GPTQ-quantized model together with its tokenizer.

    The ``load_quant`` loader is imported at call time from a
    ``repositories/GPTQ-for-LLaMa`` directory resolved relative to this
    file's location. If the import fails, an error and a documentation link
    are printed and the process exits.

    Args:
        model_name: Passed to both ``AutoTokenizer.from_pretrained`` and
            ``load_quant``.
        gptq_config: Quantization settings; ``ckpt`` is resolved through
            ``find_gptq_ckpt`` and ``wbits``/``groupsize``/``act_order`` are
            forwarded to ``load_quant``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    print("Loading GPTQ quantized model...")

    # Put the GPTQ-for-LLaMa directory at the front of the import path so
    # that its `llama` module can be imported below.
    this_file = os.path.realpath(__file__)
    base_dir = os.path.dirname(os.path.dirname(this_file))
    gptq_repo_dir = os.path.join(base_dir, "../repositories/GPTQ-for-LLaMa")
    sys.path.insert(0, gptq_repo_dir)

    try:
        from llama import load_quant
    except ImportError as err:
        print(f"Error: Failed to load GPTQ-for-LLaMa. {err}")
        print("See https://github.com/lm-sys/FastChat/blob/main/docs/gptq.md")
        sys.exit(-1)

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    # only `fastest-inference-4bit` branch cares about `act_order`; for other
    # branches the keyword is left out of the call entirely.
    optional_kwargs = {}
    if gptq_config.act_order:
        optional_kwargs["act_order"] = gptq_config.act_order

    model = load_quant(
        model_name,
        find_gptq_ckpt(gptq_config),
        gptq_config.wbits,
        gptq_config.groupsize,
        **optional_kwargs,
    )
    return model, tokenizer
def find_gptq_ckpt(gptq_config: GptqConfig):
    """Resolve ``gptq_config.ckpt`` to a concrete checkpoint file path.

    If ``ckpt`` points at an existing file it is returned unchanged.
    Otherwise it is treated as a directory and searched, first for ``*.pt``
    and then for ``*.safetensors``; the last match in sorted order of the
    first pattern that matches anything is returned as a string.

    Args:
        gptq_config: Configuration whose ``ckpt`` attribute is a file path,
            a directory path, or ``None``.

    Returns:
        The path of the checkpoint file as a string.

    Raises:
        SystemExit: With status 1, after printing an error, when ``ckpt`` is
            ``None`` or no checkpoint file can be found.
    """
    # `ckpt` defaults to None; Path(None) would raise TypeError, so an unset
    # checkpoint is reported through the same "not found" path instead.
    if gptq_config.ckpt is not None:
        ckpt_path = Path(gptq_config.ckpt)
        if ckpt_path.is_file():
            return gptq_config.ckpt

        for pattern in ["*.pt", "*.safetensors"]:
            candidates = sorted(ckpt_path.glob(pattern))
            if candidates:
                return str(candidates[-1])

    print("Error: gptq checkpoint not found")
    sys.exit(1)
|