Spaces:
Sleeping
Sleeping
# coding=utf-8
# Adapted from
# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py
# Copyright (c) Alibaba Cloud.
# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE
"""Inference-only QWen model compatible with HuggingFace weights."""
from typing import Any, Dict, List, Optional, Tuple | |
import torch | |
from torch import nn | |
from vllm.model_executor.input_metadata import InputMetadata | |
from vllm.model_executor.layers.activation import SiluAndMul | |
from vllm.model_executor.layers.attention import PagedAttention | |
from vllm.model_executor.layers.layernorm import RMSNorm | |
from vllm.model_executor.layers.linear import (LinearMethodBase, | |
MergedColumnParallelLinear, | |
QKVParallelLinear, | |
RowParallelLinear) | |
from vllm.model_executor.layers.rotary_embedding import get_rope | |
from vllm.model_executor.layers.sampler import Sampler | |
from vllm.model_executor.layers.vocab_parallel_embedding import ( | |
VocabParallelEmbedding, ParallelLMHead) | |
from vllm.model_executor.parallel_utils.parallel_state import ( | |
get_tensor_model_parallel_world_size) | |
from vllm.model_executor.sampling_metadata import SamplingMetadata | |
from vllm.model_executor.weight_utils import (default_weight_loader, | |
hf_model_weights_iterator) | |
from vllm.sequence import SamplerOutput | |
from vllm.transformers_utils.configs.qwen import QWenConfig | |
# Per-layer KV cache handed to PagedAttention: a (key_cache, value_cache)
# pair of tensors, unpacked as ``k_cache, v_cache = kv_cache`` in
# QWenAttention.forward.
KVCache = Tuple[torch.Tensor, torch.Tensor]
class QWenMLP(nn.Module):
    """QWen feed-forward block: merged gate/up projection, gated SiLU, output projection.

    Args:
        hidden_size: Width of the input and output activations.
        intermediate_size: Width of each of the two merged input projections;
            ``gate_up_proj`` is built with output partitions
            ``[intermediate_size] * 2``.
        hidden_act: Activation name. Only ``"silu"`` is supported.
        linear_method: Optional ``LinearMethodBase`` forwarded to both
            linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str = "silu",
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        # Validate before allocating any weights so an unsupported config
        # fails immediately instead of after building both projections.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.c_proj = RowParallelLinear(intermediate_size,
                                        hidden_size,
                                        bias=False,
                                        linear_method=linear_method)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the MLP to ``x`` and return a tensor of width ``hidden_size``."""
        # Both linear layers return a 2-tuple; only the first element (the
        # projected output) is used, since both are built with bias=False.
        gate_up, _ = self.gate_up_proj(x)
        # NOTE(review): SiluAndMul presumably computes silu(first half) *
        # second half of the merged output — confirm against vLLM.
        x = self.act_fn(gate_up)
        x, _ = self.c_proj(x)
        return x
class QWenAttention(nn.Module):
    """QWen multi-head self-attention with rotary position embeddings.

    Heads are divided evenly across the tensor-parallel group; each rank
    keeps ``num_heads // world_size`` of them. Query, key and value come from
    one fused ``c_attn`` projection (with bias), and ``c_proj`` (no bias)
    maps the attention output back to ``hidden_size``.

    Args:
        hidden_size: Model width; ``head_dim`` is ``hidden_size // num_heads``.
        num_heads: Total number of attention heads across all ranks.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        linear_method: Optional ``LinearMethodBase`` forwarded to the linear
            layers.

    Raises:
        ValueError: If ``num_heads`` is not divisible by the tensor-parallel
            world size.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
        )
        self.total_num_heads = num_heads
        # An explicit check (not ``assert``) so it still runs under
        # ``python -O``; otherwise the per-rank head count would silently
        # floor and fail later with an opaque shape mismatch.
        if self.total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_heads ({self.total_num_heads}) must be divisible by "
                f"the tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.c_attn = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.c_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        # Standard 1/sqrt(head_dim) attention scaling.
        self.scaling = self.head_dim**-0.5
        # Rotary embedding is applied to the full head dimension.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states`` with this layer's KV cache.

        Args:
            positions: Token positions fed to the rotary embedding.
            hidden_states: Input activations.
            kv_cache: ``(key_cache, value_cache)`` pair for this layer.
            input_metadata: Attention metadata passed through to PagedAttention.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # c_attn emits q, k and v concatenated along the last dimension, in
        # that order, so three equal chunks recover them.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.c_proj(attn_output)
        return output
class QWenBlock(nn.Module):
    """A single QWen decoder layer: RMSNorm -> attention, RMSNorm -> MLP.

    The residual stream is threaded through the norms rather than added
    explicitly here. The block returns its MLP output and the residual as a
    separate pair for the caller to carry into the next layer.

    Args:
        config: Model config. ``rope_theta`` falls back to 10000 and
            ``rope_scaling`` to ``None`` when the config lacks them.
        linear_method: Optional ``LinearMethodBase`` forwarded to the
            attention and MLP linear layers.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        # Registration order (ln_1, attn, ln_2, mlp) is kept stable so that
        # parameter enumeration order does not change.
        self.ln_1 = RMSNorm(width, eps=norm_eps)
        self.attn = QWenAttention(
            width,
            config.num_attention_heads,
            config.max_position_embeddings,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            linear_method=linear_method,
        )
        self.ln_2 = RMSNorm(width, eps=norm_eps)
        # NOTE(review): the halving presumably reflects QWen's config
        # counting both MLP input projections in intermediate_size — confirm.
        half_intermediate = config.intermediate_size // 2
        self.mlp = QWenMLP(width, half_intermediate, linear_method=linear_method)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP sub-layers.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or the embeddings).
            kv_cache: This layer's ``(key_cache, value_cache)`` pair.
            input_metadata: Attention metadata passed to the attention layer.
            residual: Residual from the previous layer, or ``None`` for the
                first layer.

        Returns:
            ``(mlp_output, residual)``, where ``residual`` is the second value
            returned by ``ln_2``.
        """
        if residual is not None:
            # Two-argument form: ln_1 returns a (normed, residual) pair.
            # NOTE(review): presumably it adds the residual before
            # normalizing — confirm against vLLM's RMSNorm.
            normed, residual = self.ln_1(hidden_states, residual)
        else:
            # First layer: the raw input seeds the residual stream.
            residual = hidden_states
            normed = self.ln_1(hidden_states)

        attn_out = self.attn(
            positions=positions,
            hidden_states=normed,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        mlp_in, residual = self.ln_2(attn_out, residual)
        return self.mlp(mlp_in), residual
class QWenModel(nn.Module):
    """QWen transformer trunk: token embedding, decoder blocks, final RMSNorm.

    Args:
        config: Model config supplying vocab size, width, layer count and
            the norm epsilon.
        linear_method: Optional ``LinearMethodBase`` forwarded to every block.
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.wte = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.h = nn.ModuleList([
            QWenBlock(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder block.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions, forwarded to every block.
            kv_caches: One KV cache per block, in the same order as ``self.h``.
                Indexed, not zipped, so a too-short list raises ``IndexError``
                instead of silently skipping layers.
            input_metadata: Attention metadata forwarded to every block.

        Returns:
            Hidden states after the final ``ln_f`` normalization.
        """
        hidden_states = self.wte(input_ids)
        # The first block receives None and seeds the residual from its input.
        residual = None
        for i, layer in enumerate(self.h):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        # Blocks return their MLP output and residual separately, so the
        # final norm receives both; its second return value is unused.
        hidden_states, _ = self.ln_f(hidden_states, residual)
        return hidden_states
class QWenLMHeadModel(nn.Module):
    """QWen causal LM: transformer trunk plus a parallel LM head and sampler.

    ``forward`` returns hidden states only; ``sample`` turns them into next
    tokens using the LM head weight.

    Args:
        config: Model config (vocab size and width are read here).
        linear_method: Optional ``LinearMethodBase``; stored and forwarded to
            the trunk (the LM head is built without it).
    """

    def __init__(
        self,
        config: QWenConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.transformer = QWenModel(config, linear_method)
        vocab = config.vocab_size
        self.lm_head = ParallelLMHead(vocab, config.hidden_size)
        self.sampler = Sampler(vocab)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states for ``input_ids``."""
        return self.transformer(input_ids, positions, kv_caches,
                                input_metadata)

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``hidden_states`` via the LM head weight."""
        return self.sampler(self.lm_head.weight, hidden_states,
                            sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Copy HuggingFace QWen checkpoint tensors into this model.

        Checkpoint names containing ``rotary_emb.inv_freq`` are skipped, as
        are ``.bias`` entries with no matching parameter (GPTQ checkpoints).
        Any other name that is not a parameter of this model raises
        ``KeyError``.
        """
        # (param_name, checkpoint_name, shard_id): the checkpoint keeps the
        # MLP input projections as separate w1/w2 tensors, while this model
        # merges them into gate_up_proj with w2 as shard 0 and w1 as shard 1.
        # NOTE(review): w2 presumably goes first because it is the branch the
        # SiLU gate is applied to — confirm against the HF implementation.
        merged_shards = (
            ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
        )
        params = dict(self.named_parameters())

        def is_missing_bias(key: str) -> bool:
            # Skip loading extra bias for GPTQ models.
            return key.endswith(".bias") and key not in params

        weights = hf_model_weights_iterator(model_name_or_path, cache_dir,
                                            load_format, revision)
        for name, tensor in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for target, source, shard in merged_shards:
                if source not in name:
                    continue
                # The rename carries over into later iterations and into the
                # else-branch below, exactly as a mutated loop variable.
                name = name.replace(source, target)
                if is_missing_bias(name):
                    continue
                param = params[name]
                param.weight_loader(param, tensor, shard)
                break
            else:
                # Not a merged shard (or an unmatched shard bias).
                if is_missing_bias(name):
                    continue
                param = params[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor)