# coding=utf-8
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch.nn import LayerNorm

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import ChatGLMConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
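# Each KVCache entry is the (key_cache, value_cache) pair allocated for one
# layer by vLLM's cache engine; PagedAttention reads and writes the paged
# blocks through InputMetadata, so this module only passes the tensors along.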


class GLMAttention(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
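        # Illustrative example (numbers are hypothetical, not from a specific
        # config): with total_num_heads=32, total_num_kv_heads=2, tp_size=2,
        # each rank keeps num_heads=16 and num_kv_heads=1, i.e. grouped-query
        # attention where all query heads on a rank share one KV head.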
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            linear_method=linear_method,
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )
        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )
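        # Rotary embeddings are applied to only half of each head
        # (rotary_dim = head_dim // 2) in the interleaved GPT-J style
        # (is_neox_style=False). rope_ratio scales the base frequency for
        # long-context checkpoints such as chatglm3-6b-32k (see link above).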
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
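        # The fused QKV output is split into the per-rank query chunk
        # (q_size = num_heads * head_dim) and the key/value chunks
        # (kv_size = num_kv_heads * head_dim each); rotary position
        # embeddings are then applied to q and k below.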
        q, k = self.rotary_emb(position_ids, q, k)
        key_cache, value_cache = kv_cache
        context_layer = self.attn(
            q,
            k,
            v,
            key_cache,
            value_cache,
            input_metadata,
        )
        attn_output, _ = self.dense(context_layer)
        return attn_output


class GLMMLP(nn.Module):
    """MLP.

    Takes the input with hidden size h, projects it to the 4*h intermediate
    size, performs a nonlinear transformation, and projects the state back
    to hidden size h.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.add_bias = config.add_bias_linear

        # Project to 4h.
        self.dense_h_to_4h = MergedColumnParallelLinear(
            config.hidden_size,
            [config.ffn_hidden_size] * 2,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )
        self.activation_func = SiluAndMul()
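        # dense_h_to_4h emits two ffn_hidden_size-wide projections
        # concatenated along the last dim; SiluAndMul splits them and
        # computes silu(first) * second, the SwiGLU-style gating used by
        # ChatGLM2/3.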
        # Project back to h.
        self.dense_4h_to_h = RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=config.add_bias_linear,
            linear_method=linear_method,
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel, _ = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output, _ = self.dense_4h_to_h(intermediate_parallel)
        return output


class GLMBlock(nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

        self.fp32_residual_connection = config.fp32_residual_connection

        layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
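        # ChatGLM2/3 configs normally set rmsnorm=True, so RMSNorm is the
        # common path here; plain LayerNorm is kept for configs that disable
        # it.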
        # Layernorm on the input data.
        self.input_layernorm = layer_norm_func(config.hidden_size,
                                               eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = GLMAttention(config, linear_method)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = layer_norm_func(
            config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = GLMMLP(config, linear_method)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # hidden_states: [num_tokens, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output = self.self_attention(
            hidden_states=layernorm_output,
            position_ids=position_ids,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = residual + attention_output

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = self.mlp(layernorm_output) + residual

        return output


class GLMTransformer(nn.Module):
    """Transformer class."""

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        self.layers = nn.ModuleList(
            [GLMBlock(config, linear_method) for i in range(self.num_layers)])

        if self.post_layer_norm:
            layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = layer_norm_func(
                config.hidden_size, eps=config.layernorm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        for i in range(self.num_layers):
            layer = self.layers[i]
            hidden_states = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                kv_cache=kv_caches[i],
                input_metadata=input_metadata,
            )
        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states


class ChatGLMModel(nn.Module):

    def __init__(
        self,
        config,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()

        self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
                                                config.hidden_size)

        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels
        self.encoder = GLMTransformer(config, linear_method)

        self.output_layer = ParallelLMHead(config.padded_vocab_size,
                                           config.hidden_size)
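        # The LM head is a separate ParallelLMHead rather than being tied to
        # the input embedding; its weight is handed to the sampler via
        # ChatGLMForCausalLM.lm_head_weight below.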

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        inputs_embeds = self.embedding(input_ids)

        # Run encoder.
        hidden_states = self.encoder(
            hidden_states=inputs_embeds,
            position_ids=position_ids,
            kv_caches=kv_caches,
            input_metadata=input_metadata,
        )
        return hidden_states


class ChatGLMForCausalLM(nn.Module):

    def __init__(
        self,
        config: ChatGLMConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config: ChatGLMConfig = config
        self.linear_method = linear_method
        self.transformer = ChatGLMModel(config, linear_method)
        self.lm_head_weight = self.transformer.output_layer.weight
        self.sampler = Sampler(config.padded_vocab_size)
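        # The sampler projects the final hidden states onto the vocabulary
        # with lm_head_weight and samples next tokens according to
        # SamplingMetadata (see sample() below).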

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_pos_emb.inv_freq" in name:
                continue
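            # HF ChatGLM checkpoints store the token embedding under
            # "transformer.embedding.word_embeddings.weight"; dropping the
            # ".word_embeddings" segment below maps it onto the
            # VocabParallelEmbedding registered as transformer.embedding.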
if "word_embeddings" in name: | |
name = name.replace(".word_embeddings", "") | |
# Skip loading extra bias for GPTQ models. | |
if name.endswith(".bias") and name not in params_dict: | |
continue | |
param = params_dict[name] | |
weight_loader = getattr(param, "weight_loader", | |
default_weight_loader) | |
weight_loader(param, loaded_weight) | |
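
# Usage sketch (illustrative; the model name and prompt are just examples):
# this class is normally constructed by vLLM's engine via the model registry
# rather than instantiated directly, e.g.:
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="THUDM/chatglm3-6b", trust_remote_code=True)
#     outputs = llm.generate(["Hello, my name is"],
#                            SamplingParams(max_tokens=32))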