"""Inference-only LLaMA model compatible with HuggingFace weights.""" |
|
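# NOTE: This module is normally constructed and driven by the vLLM engine
# rather than used directly. A rough, illustrative sketch of end-to-end usage
# (assuming the standard top-level vLLM API) looks like:
#
#     from vllm import LLM, SamplingParams
#
#     llm = LLM(model="meta-llama/Llama-2-7b-hf")
#     outputs = llm.generate(["Hello, my name is"],
#                            SamplingParams(temperature=0.8, max_tokens=32))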
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
    LinearMethodBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding,
    ParallelLMHead,
)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (
    default_weight_loader,
    hf_model_weights_iterator,
)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]
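# Each KVCache entry is the (key_cache, value_cache) tensor pair that
# PagedAttention reads from and writes to for a single decoder layer.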


class LlamaMLP(nn.Module):
    """LLaMA feed-forward block: fused gate/up projection, SiLU-and-multiply
    activation, and a row-parallel down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size, hidden_size, bias=False, linear_method=linear_method
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class LlamaAttention(nn.Module):
    """Multi-head attention with rotary embeddings and a paged KV cache."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # There are at least as many KV heads as tensor-parallel ranks,
            # so the KV heads are partitioned across the ranks.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # There are fewer KV heads than tensor-parallel ranks, so each
            # KV head is replicated across multiple ranks.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(
            self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class LlamaDecoderLayer(nn.Module):
    """A single LLaMA transformer layer: self-attention followed by the MLP,
    each preceded by RMSNorm and wrapped in a residual connection."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.self_attn = LlamaAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self-attention block.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully connected (MLP) block.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class LlamaModel(nn.Module):
    """The LLaMA transformer stack: token embeddings, decoder layers, and the
    final RMSNorm. ``forward`` consumes input embeddings rather than token ids;
    callers are expected to apply ``embed_tokens`` first."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList(
            [
                LlamaDecoderLayer(config, linear_method)
                for _ in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_emb: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = input_emb
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class LlamaForCausalLM(nn.Module):
    """LLaMA model with a language-modeling head and sampler, as used by the
    vLLM engine."""

    def __init__(
        self,
        config: LlamaConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = LlamaModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # LlamaModel.forward expects embeddings, so embed the token ids first.
        input_emb = self.model.embed_tokens(input_ids)
        hidden_states = self.model(input_emb, positions, kv_caches, input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(
            self.lm_head.weight, hidden_states, sampling_metadata
        )
        return next_tokens

    def load_weights(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        load_format: str = "auto",
        revision: Optional[str] = None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision
        ):
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Some checkpoints include cached rotary embedding tensors;
                # they are recomputed at runtime, so skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)