from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.falcon.modeling_falcon import (
    FalconDecoderLayer as OldFalconDecoderLayer,
    FalconForCausalLM,
    FalconAttention,
)


class FalconAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "FalconDecoderLayer"

    @staticmethod
    def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
        fuser = FalconFuser(model)
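
        # Fused (fast-inference) modules are only built for the 7B-style
        # architecture, identified here by its 71 attention heads; larger
        # Falcon variants keep the stock Hugging Face decoder layers.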
        if model.config.num_attention_heads == 71:
            fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: FalconForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: OldFalconDecoderLayer):
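        # The MLP activation is scalable: its scale vector matches the
        # out_features of the preceding dense_h_to_4h projection.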
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.dense_h_to_4h.out_features,
        )

    @staticmethod
    def move_embed(model: FalconForCausalLM, device):
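        # Falcon keeps its token embeddings at transformer.word_embeddings.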
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs):
        layers = []
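
        # Falcon 7B: a single input_layernorm feeds both the attention QKV
        # projection and the MLP up-projection, so both are scaled as one group.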
        if module.config.num_attention_heads == 71:
            layers.append(dict(
                prev_op=module.input_layernorm,
                layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
                inp=input_feat['self_attention.query_key_value'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))
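
        # Falcon 40B/180B: separate ln_attn and ln_mlp layernorms, so the QKV
        # projection and the MLP up-projection form independent scale groups.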
        else:
            layers.append(dict(
                prev_op=module.ln_attn,
                layers=[module.self_attention.query_key_value],
                inp=input_feat['self_attention.query_key_value'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))

            layers.append(dict(
                prev_op=module.ln_mlp,
                layers=[module.mlp.dense_h_to_4h],
                inp=input_feat['mlp.dense_h_to_4h'],
                module2inspect=module,
                kwargs=module_kwargs,
            ))

        return layers


from awq.modules.fused.model import FalconModel
from awq.modules.fused.block import FalconDecoderLayer


class FalconFuser:
    def __init__(self, model: FalconForCausalLM):
        self.model = model

    def fuse_transformer(self):
        blocks = []

        module: OldFalconDecoderLayer
        for module in self.model.transformer.h:
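            # 7B-style blocks share one layernorm; 40B/180B-style blocks
            # ("new decoder architecture") carry separate ln_attn and ln_mlp.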
            if module.config.num_attention_heads == 71:
                input_layernorm = module.input_layernorm
                ln_attn = None
                ln_mlp = None
                new_decoder_arch = False
            else:
                input_layernorm = None
                ln_attn = module.ln_attn
                ln_mlp = module.ln_mlp
                new_decoder_arch = True
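
            # Rebuild the block as a fused decoder layer that reuses the
            # (quantized) weights of the original Hugging Face module.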
            blocks.append(FalconDecoderLayer(
                hidden_size=module.config.hidden_size,
                n_heads=module.config.num_attention_heads,
                qkv_layer=module.self_attention.query_key_value,
                o_proj=module.self_attention.dense,
                mlp=module.mlp,
                dev=next(iter(module.state_dict().values())).device,
                max_seq_len=self.model.config.max_new_tokens,
                input_layernorm=input_layernorm,
                ln_attn=ln_attn,
                ln_mlp=ln_mlp,
                new_decoder_arch=new_decoder_arch,
            ))
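
        # Swap the Hugging Face transformer for the fused implementation,
        # reusing the original embeddings and final layernorm.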
        self.model.transformer = FalconModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.word_embeddings,
            self.model.transformer.ln_f,
        )