import os

from transformers import AutoConfig

# The *AWQForCausalLM classes referenced below are re-exported by awq.models.
from awq.models import *
from awq.models.base import BaseAWQForCausalLM

# Maps a Hugging Face `config.model_type` to the AWQ class that implements it.
# Falcon appears under three keys because early checkpoints shipped with the
# "RefinedWeb"/"RefinedWebModel" names before the architecture was renamed.
AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
    "llama": LlamaAWQForCausalLM,
    "opt": OptAWQForCausalLM,
    "RefinedWeb": FalconAWQForCausalLM,
    "RefinedWebModel": FalconAWQForCausalLM,
    "falcon": FalconAWQForCausalLM,
    "bloom": BloomAWQForCausalLM,
    "gptj": GPTJAWQForCausalLM,
    "gpt_bigcode": GptBigCodeAWQForCausalLM,
    "mistral": MistralAWQForCausalLM,
}

def check_and_get_model_type(model_dir, trust_remote_code=True):
    """Read the checkpoint config and return its `model_type`, raising if unsupported."""
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type
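
# For example, a checkpoint whose config declares `model_type: "llama"` resolves
# to the corresponding key of AWQ_CAUSAL_LM_MODEL_MAP (checkpoint id is illustrative):
#
#   check_and_get_model_type("meta-llama/Llama-2-7b-hf")  # -> "llama"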

class AutoAWQForCausalLM:
    """Dispatch entry point: resolves a checkpoint's model type and forwards to the matching AWQ class."""

    def __init__(self):
        raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                               'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
    
    @classmethod
    def from_pretrained(cls, model_path, trust_remote_code=True, safetensors=False,
                        device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM:
        """Load an unquantized model via the AWQ class registered for its model type."""
        model_type = check_and_get_model_type(model_path, trust_remote_code)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
            device_map=device_map, **model_init_kwargs
        )

    @classmethod
    def from_quantized(cls, quant_path, quant_filename='', max_new_tokens=None,
                       trust_remote_code=True, fuse_layers=True,
                       batch_size=1, safetensors=False,
                       max_memory=None, offload_folder=None) -> BaseAWQForCausalLM:
        """Load an AWQ-quantized checkpoint via the AWQ class registered for its model type."""
        # Expose the requested batch size to downstream modules through the environment.
        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
        model_type = check_and_get_model_type(quant_path, trust_remote_code)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
            fuse_layers=fuse_layers, safetensors=safetensors,
            max_memory=max_memory, offload_folder=offload_folder
        )
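
# Usage sketch (checkpoint paths are illustrative, not part of this module).
# Quantization itself happens through methods on the returned
# BaseAWQForCausalLM instance; this wrapper only handles dispatch.
#
#   model = AutoAWQForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
#   quantized = AutoAWQForCausalLM.from_quantized("path/to/quantized-model",
#                                                 fuse_layers=True, batch_size=1)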