Spaces:
Sleeping
Sleeping
# coding=utf-8 | |
# Adapted from | |
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py | |
# Copyright 2023 The vLLM team. | |
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. | |
# | |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | |
# and OPT implementations in this library. It has been modified from its | |
# original forms to accommodate minor architectural differences compared | |
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Inference-only Mixtral model.""" | |
from typing import List, Optional, Tuple | |
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from transformers import MixtralConfig | |
from vllm.model_executor.input_metadata import InputMetadata | |
from vllm.model_executor.layers.attention import PagedAttention | |
from vllm.model_executor.layers.fused_moe import fused_moe | |
from vllm.model_executor.layers.layernorm import RMSNorm | |
from vllm.model_executor.layers.linear import (LinearMethodBase, | |
QKVParallelLinear, | |
ReplicatedLinear, | |
RowParallelLinear) | |
from vllm.model_executor.layers.rotary_embedding import get_rope | |
from vllm.model_executor.layers.sampler import Sampler | |
from vllm.model_executor.layers.vocab_parallel_embedding import ( | |
VocabParallelEmbedding, ParallelLMHead) | |
from vllm.model_executor.parallel_utils.communication_op import ( | |
tensor_model_parallel_all_reduce) | |
from vllm.model_executor.parallel_utils.parallel_state import ( | |
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) | |
from vllm.model_executor.sampling_metadata import SamplingMetadata | |
from vllm.model_executor.utils import set_weight_attrs | |
from vllm.model_executor.weight_utils import (default_weight_loader, | |
hf_model_weights_iterator) | |
from vllm.sequence import SamplerOutput | |
# Key/value cache for a single layer: a (key_cache, value_cache) tensor pair.
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMoE(nn.Module):
    """A tensor-parallel MoE implementation for Mixtral that shards each expert
    across all ranks.

    Each expert's weights are sharded across all ranks and a fused MoE
    kernel is used for the forward pass, and finally we reduce the outputs
    across ranks.

    Args:
        num_experts: Total number of experts; every rank holds a shard of
            each of them.
        top_k: Number of experts selected per token by the router.
        hidden_size: Size of the model's hidden dimension.
        intermediate_size: Full (unsharded) size of each expert's
            intermediate dimension. Must be divisible by the tensor-parallel
            world size.
        params_dtype: Dtype of the gate and expert weights. Defaults to
            ``torch.get_default_dtype()`` when ``None``.

    Raises:
        ValueError: If ``intermediate_size`` is not divisible by the
            tensor-parallel world size.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        tp_size = get_tensor_model_parallel_world_size()
        # A floor division here would silently drop the trailing part of
        # every expert weight when the sizes do not line up, so fail loudly.
        if intermediate_size % tp_size != 0:
            raise ValueError(
                f"intermediate_size ({intermediate_size}) must be divisible "
                f"by the tensor parallel world size ({tp_size}).")
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Per-rank slice of the intermediate dimension.
        self.intermediate_size = intermediate_size // tp_size

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.params_dtype = params_dtype

        # Router producing one logit per expert.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     linear_method=None)

        # Per expert, rows [0, intermediate_size) hold the w1 shard and rows
        # [intermediate_size, 2 * intermediate_size) hold the w3 shard
        # (see weight_loader).
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        device="cuda",
                        dtype=self.params_dtype))
        # Per expert, the w2 shard, split along its last dimension.
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        device="cuda",
                        dtype=self.params_dtype))

        set_weight_attrs(self.ws, {
            "weight_loader": self.weight_loader,
        })
        set_weight_attrs(self.w2s, {
            "weight_loader": self.weight_loader,
        })

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's shard of one expert weight into ``param``.

        Args:
            param: Destination parameter; indexed by ``expert_id`` on its
                first dimension.
            loaded_weight: The checkpoint tensor for a single expert weight.
                It is sliced here, so it is expected to be unsharded.
            weight_name: Checkpoint weight name. Its suffix (``w1.weight``,
                ``w3.weight`` or ``w2.weight``) selects the destination
                layout; any other suffix is ignored.
            expert_id: Index of the expert the weight belongs to.
        """
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param.data
        shard_size = self.intermediate_size
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        if weight_name.endswith("w1.weight"):
            # w1 is sliced on dim 0 and fills the first half of the rows.
            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w3.weight"):
            # w3 is sliced on dim 0 and fills the second half of the rows.
            param_data[expert_id,
                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
        elif weight_name.endswith("w2.weight"):
            # w2 is sliced on dim 1 instead.
            param_data[expert_id, :, :] = loaded_weight[:, shard]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route every token to its top-k experts and combine their outputs.

        Args:
            hidden_states: A 3-D tensor unpacked as
                (batch_size, sequence_length, hidden_size).

        Returns:
            A tensor viewed back to the input's 3-D shape.
        """
        batch_size, sequence_length, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

        # Softmax is computed in float32, then only the top-k probabilities
        # are kept and renormalized so they sum to one per token.
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        # NOTE(review): inplace=True presumably lets the kernel write into
        # hidden_states' storage -- confirm against fused_moe.
        final_hidden_states = fused_moe(hidden_states,
                                        self.ws,
                                        self.w2s,
                                        routing_weights,
                                        selected_experts,
                                        inplace=True)

        # Each rank only holds a slice of every expert, so the per-rank
        # results are reduced across the tensor-parallel group.
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(batch_size, sequence_length,
                                        hidden_size)
class MixtralAttention(nn.Module):
    """Self-attention block with a fused QKV projection, rotary position
    embeddings and paged attention, split across tensor-parallel ranks.
    """

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = hidden_size

        # Query heads are always partitioned evenly across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world_size:
            # Fewer KV heads than ranks: the KV heads are replicated
            # across multiple tensor parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        else:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        # Per-rank sizes used to split the fused QKV output in forward().
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply rotary embeddings, attend, project out."""
        fused_qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = fused_qkv.split(split_sizes, dim=-1)
        query, key = self.rotary_emb(positions, query, key)

        key_cache, value_cache = kv_cache
        context = self.attn(query, key, value, key_cache, value_cache,
                            input_metadata)

        projected, _ = self.o_proj(context)
        return projected
class MixtralDecoderLayer(nn.Module):
    """One Mixtral decoder layer: RMSNorm, self-attention, RMSNorm and a
    sparse mixture-of-experts block.

    Args:
        config: Model configuration supplying the layer dimensions.
        linear_method: Optional linear method forwarded to the attention
            projections. It is not passed to the MoE block.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer and return its output together with the residual.

        Args:
            positions: Token positions forwarded to the attention block.
            hidden_states: Input activations of this layer.
            kv_cache: Key/value cache pair for this layer.
            input_metadata: Attention metadata forwarded unchanged.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer, in which case ``hidden_states`` itself
                becomes the residual.

        Returns:
            A ``(hidden_states, residual)`` pair. The two are returned
            separately rather than summed here; the next norm call that
            receives both is expected to combine them.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # NOTE(review): with a residual argument the norm returns a
            # pair; presumably it adds the residual before normalizing --
            # confirm against RMSNorm.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
class MixtralModel(nn.Module):
    """Token embedding, the stack of decoder layers and the final norm."""

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        decoder_layers = [
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        ``kv_caches`` is indexed by layer position, so it needs at least one
        entry per layer.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer receives no residual and seeds it itself.
        residual = None
        for layer_idx, decoder_layer in enumerate(self.layers):
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    kv_caches[layer_idx],
                                                    input_metadata, residual)
        # The final norm consumes both the last output and the residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
class MixtralForCausalLM(nn.Module):
    """Mixtral decoder stack with an LM head and sampler, plus checkpoint
    loading.
    """

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Return the hidden states produced by the underlying model."""
        return self.model(input_ids, positions, kv_caches, input_metadata)

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``hidden_states`` using the LM head weight."""
        return self.sampler(self.lm_head.weight, hidden_states,
                            sampling_metadata)

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        """Load checkpoint weights into this model's parameters.

        Each checkpoint tensor is matched, in order, against:
          1. the q/k/v projections, which are loaded into the fused
             ``qkv_proj`` parameter with the matching shard id;
          2. the per-expert ``w1``/``w2``/``w3`` weights, which are loaded
             into the stacked ``ws`` / ``w2s`` parameters;
          3. everything else, loaded by name with the parameter's own
             ``weight_loader`` or ``default_weight_loader``.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        qkv_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # (stacked parameter name, checkpoint weight name, expert id)
        expert_mapping = []
        for expert_idx in range(self.config.num_local_experts):
            for proj in ["w1", "w2", "w3"]:
                stacked_name = "ws" if proj in ["w1", "w3"] else "w2s"
                expert_mapping.append(
                    (stacked_name, f"experts.{expert_idx}.{proj}.weight",
                     expert_idx))

        parameters = dict(self.named_parameters())
        checkpoint_weights = hf_model_weights_iterator(model_name_or_path,
                                                       cache_dir,
                                                       load_format,
                                                       revision,
                                                       fall_back_to_pt=False)
        for name, loaded_weight in checkpoint_weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for fused_name, shard_name, shard_id in qkv_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                # This `continue` only advances the mapping loop; the
                # renamed bias then falls through to the generic branch,
                # which skips it.
                if name.endswith(".bias") and name not in parameters:
                    continue
                target = parameters[name]
                load_shard = target.weight_loader
                load_shard(target, loaded_weight, shard_id)
                break
            else:
                for stacked_name, expert_weight, expert_id in expert_mapping:
                    if expert_weight not in name:
                        continue
                    name = name.replace(expert_weight, stacked_name)
                    target = parameters[name]
                    load_expert = target.weight_loader
                    load_expert(target,
                                loaded_weight,
                                expert_weight,
                                expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    # This `continue` is outside both mapping loops, so it
                    # moves on to the next checkpoint tensor.
                    if name.endswith(".bias") and name not in parameters:
                        continue
                    target = parameters[name]
                    load_plain = getattr(target, "weight_loader",
                                         default_weight_loader)
                    load_plain(target, loaded_weight)