# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# From https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/te_llama/te_llama.py
#
# Edited by fumiama.
import re
from contextlib import contextmanager
from typing import Dict
import transformer_engine as te
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import torch
import transformers
from transformers.models.llama.modeling_llama import (
LlamaModel,
LlamaConfig,
)
from transformers.modeling_utils import _load_state_dict_into_model
from .patch import LlamaRMSNorm


@contextmanager
def replace_decoder(te_decoder_cls, llama_rms_norm_cls):
"""
Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.
"""
original_llama_decoder_cls = (
transformers.models.llama.modeling_llama.LlamaDecoderLayer
)
transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls
original_llama_rms_norm_cls = transformers.models.llama.modeling_llama.LlamaRMSNorm
transformers.models.llama.modeling_llama.LlamaRMSNorm = llama_rms_norm_cls
try:
yield
finally:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = (
original_llama_decoder_cls
)
transformers.models.llama.modeling_llama.LlamaRMSNorm = (
original_llama_rms_norm_cls
)
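
# A minimal usage sketch of the context manager above (assuming a hypothetical
# `LlamaConfig` named `cfg` has already been built): any `LlamaModel` constructed
# inside the `with` block picks up `TELlamaDecoderLayer` and the patched
# `LlamaRMSNorm`, and the original classes are restored on exit even if
# construction fails.
#
#     with replace_decoder(
#         te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm
#     ):
#         patched_model = LlamaModel(cfg)
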
class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
"""
    Wrapper class over TE's `TransformerLayer`. It mirrors the interface of
    HF's `LlamaDecoderLayer`, so it can be swapped in for it directly.

    Args:
config: LlamaConfig
args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
"""
def __init__(self, config, *args, **kwargs):
super().__init__(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
bias=False,
layernorm_epsilon=config.rms_norm_eps,
hidden_dropout=0,
attention_dropout=0,
fuse_qkv_params=False,
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads,
)
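        # Precompute the rotary position-embedding table once for the model's full
        # context length; TE applies it inside attention via the `rotary_pos_emb`
        # argument passed in `forward` below.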
te_rope = RotaryPositionEmbedding(
config.hidden_size // config.num_attention_heads
)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
def forward(self, hidden_states, *args, attention_mask, **kwargs):
"""
        Custom forward to make sure we only pass relevant arguments to the
        forward pass of the `TransformerLayer`, and that the output format
        matches that of HF's `LlamaDecoderLayer` (a one-element tuple holding
        the hidden states).
"""
return (
super().forward(
hidden_states,
attention_mask=attention_mask,
rotary_pos_emb=self.te_rope_emb,
),
)
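
# A hedged sketch of driving one layer standalone (assumes a CUDA device and a
# hypothetical `cfg: LlamaConfig`; dtype/precision handling is omitted):
#
#     layer = TELlamaDecoderLayer(cfg).cuda()
#     x = torch.randn(1, 16, cfg.hidden_size, device="cuda")  # "bshd" layout
#     (out,) = layer(x, attention_mask=None)                  # 1-tuple, like HF's layer
#     assert out.shape == x.shape
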
class TELlamaModel:
"""
    LM created with `LlamaModel`. The underlying `LlamaDecoderLayer` class is
    monkey-patched with the `TELlamaDecoderLayer` class before the model is
    initialized with `LlamaModel`.

    Args:
config: LlamaConfig
"""
def __new__(cls, config: LlamaConfig):
with replace_decoder(
te_decoder_cls=TELlamaDecoderLayer, llama_rms_norm_cls=LlamaRMSNorm
):
model = LlamaModel(config)
return model
@classmethod
def from_state_dict(
cls,
state_dict: Dict[str, torch.Tensor],
config: LlamaConfig,
):
"""
        Custom method adapted from the `from_pretrained` method in the HuggingFace
        Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
vanilla_model = cls(config)
        # _replace_params copies the parameters relevant only to TransformerEngine
_replace_params(state_dict, vanilla_model.state_dict(), config)
# _load_state_dict_into_model copies parameters other than those in TransformerEngine
_load_state_dict_into_model(vanilla_model, state_dict, start_prefix="")
return vanilla_model
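
# A hedged usage sketch (hypothetical `cfg` and `hf_sd`): plain construction returns
# a `LlamaModel` built from TE layers, while `from_state_dict` additionally copies
# weights whose names follow the `model.layers.N.*` scheme matched by
# `_replace_params` below.
#
#     te_model = TELlamaModel(cfg)                           # patched LlamaModel
#     te_model = TELlamaModel.from_state_dict(hf_sd, cfg)    # construction + weight copy
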
def _replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
        layer_prefix_pat = r"model\.layers\.\d+\."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
for layer_prefix in all_layer_prefixes:
        # When loading weights into a model with fewer layers, skip the copy
        # if the corresponding layer doesn't exist in the HF state dict
if layer_prefix + "input_layernorm.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"
].data[:] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.query_weight"
].data[:] = hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.key_weight"
].data[:] = hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
te_state_dict[
layer_prefix + "self_attention.layernorm_qkv.value_weight"
].data[:] = hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = (
hf_state_dict[layer_prefix + "self_attn.o_proj.weight"].data[:]
)
if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = (
hf_state_dict[layer_prefix + "post_attention_layernorm.weight"].data[:]
)
        # gate_proj.weight and up_proj.weight may come from different checkpoint
        # shards, so they are loaded separately.
if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
: config.intermediate_size
] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
config.intermediate_size :
] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = (
hf_state_dict[layer_prefix + "mlp.down_proj.weight"].data[:]
)
return all_layer_prefixes
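
# Layout note on `layernorm_mlp.fc1_weight`, derived from the copies above: TE fuses
# the swiglu gate/up projections into a single matrix of shape
# [2 * intermediate_size, hidden_size]; rows [:intermediate_size] hold
# `mlp.gate_proj.weight` and rows [intermediate_size:] hold `mlp.up_proj.weight`,
# which is why the two halves can be filled from different checkpoint shards
# independently. Illustration with a hypothetical key:
#
#     fc1 = te_state_dict["model.layers.0.layernorm_mlp.fc1_weight"]
#     assert fc1.shape[0] == 2 * config.intermediate_size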