# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py
#
# Edited by fumiama.
import re
from contextlib import contextmanager
from typing import Dict

import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding

import torch
import transformers
from transformers.models.llama.modeling_llama import (
    LlamaModel,
    LlamaConfig,
)
from transformers.modeling_utils import _load_state_dict_into_model

from .patch import LlamaRMSNorm

@contextmanager
def replace_decoder(te_decoder_cls, llama_rms_norm_cls):
    """
    Temporarily replace `LlamaDecoderLayer` with the custom `TELlamaDecoderLayer`
    and `LlamaRMSNorm` with the patched RMSNorm inside `transformers`, restoring
    the original classes when the `with` block exits.
    """
    original_llama_decoder_cls = (
        transformers.models.llama.modeling_llama.LlamaDecoderLayer
    )
    transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
    original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm
    transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls
    try:
        yield
    finally:
        transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
            original_llama_decoder_cls
        )
        transformers.models.llama.modeling_llama.LlamaRMSNorm = (
            original_llama_rms_norm_cls
        )
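
# Usage sketch (illustrative only, not part of the original module): while the
# context manager is active, any `LlamaModel(config)` built inside the `with`
# block constructs its decoder layers from `te_decoder_cls` and its norms from
# `llama_rms_norm_cls`; the HF originals are restored on exit, even on error.
#
#     with replace_decoder(TELlamaDecoderLayer, LlamaRMSNorm):
#         model = LlamaModel(config)  # decoder layers are TE `TransformerLayer`s
#     # outside the block, LlamaModel(config) builds plain HF layers again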

class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
    """
    Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
    similar to HF's `LlamaDecoderLayer`, so it can be swapped in for it
    directly.

    Args:
        config: LlamaConfig
        args: positional args (for compatibility with `LlamaDecoderLayer`)
        kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
    """

    def __init__(self, config, *args, **kwargs):
        super().__init__(
            hidden_size=config.hidden_size,
            ffn_hidden_size=config.intermediate_size,
            num_attention_heads=config.num_attention_heads,
            bias=False,
            layernorm_epsilon=config.rms_norm_eps,
            hidden_dropout=0,
            attention_dropout=0,
            fuse_qkv_params=False,
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format="bshd",
            num_gqa_groups=config.num_key_value_heads,
        )
        # Precompute rotary position embeddings for the full context length.
        te_rope = RotaryPositionEmbedding(
            config.hidden_size // config.num_attention_heads
        )
        self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()

    def forward(self, hidden_states, *args, attention_mask, **kwargs):
        """
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`. Also, make sure the output
        format matches the output of HF's `LlamaDecoderLayer`.
        """
        return (
            super().forward(
                hidden_states,
                attention_mask=attention_mask,
                rotary_pos_emb=self.te_rope_emb,
            ),
        )
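
# Shape sketch (a hedged example; the config values and dtype choice are
# assumptions): with `attn_input_format="bshd"`, a layer consumes
# `[batch, seq, hidden]` tensors on a CUDA device.
#
#     cfg = LlamaConfig(num_hidden_layers=1)           # hypothetical config
#     layer = TELlamaDecoderLayer(cfg).cuda().half()
#     x = torch.randn(2, 16, cfg.hidden_size, device="cuda", dtype=torch.half)
#     (out,) = layer(x, attention_mask=None)           # out: [2, 16, cfg.hidden_size]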

class TELlamaModel:
    """
    LM created with `LlamaModel`. The underlying `LlamaDecoderLayer` and
    `LlamaRMSNorm` classes are monkey-patched with `TELlamaDecoderLayer` and
    the custom RMSNorm before the model is instantiated.

    Args:
        config: LlamaConfig
    """

    def __new__(cls, config: LlamaConfig):
        with replace_decoder(
            te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm
        ):
            model = LlamaModel(config)
        return model

    @classmethod
    def from_state_dict(
        cls,
        state_dict: Dict[str, torch.Tensor],
        config: LlamaConfig,
    ):
        """
        Custom method adapted from the `from_pretrained` method in the HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
        """
        vanilla_model = cls(config)
        # _replace_params copies the parameters that belong to TransformerEngine modules
        _replace_params(state_dict, vanilla_model.state_dict(), config)
        # _load_state_dict_into_model copies the remaining (non-TransformerEngine) parameters
        _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
        return vanilla_model
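
# End-to-end sketch (paths and the checkpoint-loading call are assumptions for
# illustration): build the TE-backed model, then fill it from a standard HF
# Llama checkpoint state dict.
#
#     config = LlamaConfig.from_pretrained("/path/to/llama")       # hypothetical path
#     state_dict = torch.load("/path/to/llama/pytorch_model.bin", map_location="cpu")
#     model = TELlamaModel.from_state_dict(state_dict, config)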

def _replace_params(hf_state_dict, te_state_dict, config):
    # collect all layer prefixes to update
    all_layer_prefixes = set()
    for param_key in hf_state_dict.keys():
        layer_prefix_pat = r"model\.layers\.\d+\."
        m = re.match(layer_prefix_pat, param_key)
        if m is not None:
            all_layer_prefixes.add(m.group())

    for layer_prefix in all_layer_prefixes:
        # When loading weights into a model with fewer layers, skip the copy
        # if the corresponding layer doesn't exist in the HF state dict.
        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
            ].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]

        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.query_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]

        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.key_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]

        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
            te_state_dict[
                layer_prefix + "self_attention.layernorm_qkv.value_weight"
            ].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]

        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
                hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
            )

        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
                hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
            )

        # gate_proj.weight and up_proj.weight may live in different checkpoint
        # shards, so they are loaded separately into the two halves of fc1_weight.
        if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                : config.intermediate_size
            ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data

        if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
                config.intermediate_size :
            ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data

        if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
            te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
                hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
            )

    return all_layer_prefixes
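
# Summary of the HF -> TransformerEngine key mapping performed above
# (per decoder layer; gate/up projections share fc1_weight, split on dim 0):
#
#     input_layernorm.weight          -> self_attention.layernorm_qkv.layer_norm_weight
#     self_attn.q_proj.weight         -> self_attention.layernorm_qkv.query_weight
#     self_attn.k_proj.weight         -> self_attention.layernorm_qkv.key_weight
#     self_attn.v_proj.weight         -> self_attention.layernorm_qkv.value_weight
#     self_attn.o_proj.weight         -> self_attention.proj.weight
#     post_attention_layernorm.weight -> layernorm_mlp.layer_norm_weight
#     mlp.gate_proj.weight            -> layernorm_mlp.fc1_weight[:intermediate_size]
#     mlp.up_proj.weight              -> layernorm_mlp.fc1_weight[intermediate_size:]
#     mlp.down_proj.weight            -> layernorm_mlp.fc2_weight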