|
import math |
|
from typing import List, Optional, Tuple |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from transformers.cache_utils import Cache |
|
from transformers.activations import ACT2FN |
|
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS |
|
from transformers.utils import logging |
|
from transformers import LlamaForCausalLM |
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel, LlamaRotaryEmbedding, LlamaRMSNorm, repeat_kv, apply_rotary_pos_emb |
|
from transformers import LlamaConfig |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
_CONFIG_FOR_DOC = "LlamaConfig" |
|
|
|
ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) |
|
|
|
class SVDLlamaMLP(nn.Module): |
|
def __init__(self, config: LlamaConfig): |
|
super().__init__() |
|
self.config = config |
|
self.hidden_size = config.hidden_size |
|
self.intermediate_size = config.intermediate_size |
|
self.ratio = config.ratio |
|
self.low_rank = int(self.intermediate_size * self.hidden_size * self.ratio / (self.intermediate_size + self.hidden_size)) |
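        # Rank budget: a dense (intermediate x hidden) matrix has intermediate*hidden
        # parameters, while the rank-r pair (u_proj, v_proj) has r*(intermediate + hidden),
        # so r = ratio * intermediate * hidden / (intermediate + hidden) keeps roughly
        # `ratio` of the original parameters. For example (assuming Llama-2-7B sizes,
        # hidden=4096, intermediate=11008, and a hypothetical ratio=0.2):
        # r = int(0.2 * 11008 * 4096 / 15104) = 597, and 597 * 15104 ~= 0.2 * 11008 * 4096.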
|
|
|
self.gate_u_proj = nn.Linear(self.low_rank, self.intermediate_size, bias=config.mlp_bias) |
|
self.gate_v_proj = nn.Linear(self.hidden_size, self.low_rank, bias=False) |
|
|
|
self.down_u_proj = nn.Linear(self.low_rank, self.hidden_size, bias=config.mlp_bias) |
|
self.down_v_proj = nn.Linear(self.intermediate_size, self.low_rank, bias=False) |
|
|
|
self.up_u_proj = nn.Linear(self.low_rank, self.intermediate_size, bias=config.mlp_bias) |
|
self.up_v_proj = nn.Linear(self.hidden_size, self.low_rank, bias=False) |
|
|
|
self.act_fn = ACT2FN[config.hidden_act] |
|
|
|
def forward(self, x): |
|
up = self.up_u_proj(self.up_v_proj(x)) |
|
gate = self.gate_u_proj(self.gate_v_proj(x)) |
|
return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up)) |
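
# A minimal sketch (an assumption for illustration, not part of the original pipeline) of how
# a dense nn.Linear could be split into the (u_proj, v_proj) pair used above via a plain
# truncated SVD. Real SVD-based compression pipelines typically whiten the weights with
# activation statistics before factorizing; this only shows the shape bookkeeping.
# `split_linear_with_svd` is a hypothetical helper name.
@torch.no_grad()
def split_linear_with_svd(linear: nn.Linear, low_rank: int) -> Tuple[nn.Linear, nn.Linear]:
    # linear.weight has shape (out_features, in_features); W ~= U_r diag(S_r) Vh_r.
    u, s, vh = torch.linalg.svd(linear.weight.float(), full_matrices=False)
    sqrt_s = torch.sqrt(s[:low_rank])
    u_proj = nn.Linear(low_rank, linear.out_features, bias=linear.bias is not None)
    v_proj = nn.Linear(linear.in_features, low_rank, bias=False)
    # Split the singular values evenly so that u_proj(v_proj(x)) ~= x @ W.T.
    u_proj.weight.copy_((u[:, :low_rank] * sqrt_s).to(linear.weight.dtype))
    v_proj.weight.copy_((sqrt_s[:, None] * vh[:low_rank]).to(linear.weight.dtype))
    if linear.bias is not None:
        u_proj.bias.copy_(linear.bias)
    return u_proj, v_proj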
|
|
|
class SVDLlamaAttention(nn.Module): |
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
|
|
|
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): |
|
super().__init__() |
|
self.config = config |
|
self.layer_idx = layer_idx |
|
if layer_idx is None: |
|
logger.warning_once( |
|
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " |
|
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " |
|
"when creating this class." |
|
) |
|
|
|
self.attention_dropout = config.attention_dropout |
|
self.hidden_size = config.hidden_size |
|
self.num_heads = config.num_attention_heads |
|
self.head_dim = self.hidden_size // self.num_heads |
|
self.num_key_value_heads = config.num_key_value_heads |
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
self.max_position_embeddings = config.max_position_embeddings |
|
self.rope_theta = config.rope_theta |
|
self.is_causal = True |
|
self.ratio = config.ratio |
|
|
|
if (self.head_dim * self.num_heads) != self.hidden_size: |
|
raise ValueError( |
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" |
|
f" and `num_heads`: {self.num_heads})." |
|
) |
|
|
|
self.q_low_rank = int(self.num_heads * self.head_dim * self.hidden_size * self.ratio / (self.num_heads * self.head_dim + self.hidden_size)) |
|
self.q_u_proj = nn.Linear(self.q_low_rank, self.num_heads * self.head_dim, bias=config.attention_bias) |
|
self.q_v_proj = nn.Linear(self.hidden_size, self.q_low_rank, bias=False) |
|
|
|
self.k_low_rank = int(self.num_key_value_heads * self.head_dim * self.hidden_size * self.ratio / (self.num_key_value_heads * self.head_dim + self.hidden_size)) |
|
self.k_u_proj = nn.Linear(self.k_low_rank, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
|
self.k_v_proj = nn.Linear(self.hidden_size, self.k_low_rank, bias=False) |
|
|
|
self.v_low_rank = int(self.num_key_value_heads * self.head_dim * self.hidden_size * self.ratio / (self.num_key_value_heads * self.head_dim + self.hidden_size)) |
|
self.v_u_proj = nn.Linear(self.v_low_rank, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) |
|
self.v_v_proj = nn.Linear(self.hidden_size, self.v_low_rank, bias=False) |
|
|
|
self.o_low_rank = int(self.hidden_size * self.hidden_size * self.ratio / (self.hidden_size + self.hidden_size)) |
|
self.o_u_proj = nn.Linear(self.o_low_rank, self.hidden_size, bias=config.attention_bias) |
|
self.o_v_proj = nn.Linear(self.hidden_size, self.o_low_rank, bias=False) |
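        # Each projection gets its own rank from the same parameter-budget rule used in
        # SVDLlamaMLP: r = ratio * (out * in) / (out + in). With grouped-query attention
        # (num_key_value_heads < num_heads) the dense k/v matrices are smaller, so their
        # ranks come out smaller than the q/o ranks for the same ratio.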
|
|
|
|
|
self.rotary_emb = LlamaRotaryEmbedding(config=self.config) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
**kwargs, |
|
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Eager (explicit matmul + softmax) attention over the low-rank projections, so
        # attention weights can be returned when `output_attentions=True`. The SDPA fast
        # path lives in the SVDLLaMASDPA subclass below.
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
query_states = self.q_u_proj(self.q_v_proj(hidden_states)) |
|
key_states = self.k_u_proj(self.k_v_proj(hidden_states)) |
|
value_states = self.v_u_proj(self.v_v_proj(hidden_states)) |
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
if position_embeddings is None: |
|
logger.warning_once( |
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " |
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " |
|
"removed and `position_embeddings` will be mandatory." |
|
) |
|
cos, sin = self.rotary_emb(value_states, position_ids) |
|
else: |
|
cos, sin = position_embeddings |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
|
if past_key_value is not None: |
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Trim the prepared causal mask to the current key length.
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # Upcast the softmax to fp32 for numerical stability, then cast back.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        # Low-rank output projection: V followed by U replaces the dense o_proj.
        attn_output = self.o_u_proj(self.o_v_proj(attn_output))

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
|
|
class SVDLLaMASDPA(SVDLlamaAttention):
    """SDPA variant of SVDLlamaAttention: uses `scaled_dot_product_attention` and falls back to the eager parent when `output_attentions=True`."""
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
if output_attentions: |
|
|
|
            logger.warning_once(
                "SVDLLaMASDPA is backed by `torch.nn.functional.scaled_dot_product_attention`, which does not "
                "support `output_attentions=True`. Falling back to the eager SVDLlamaAttention implementation; "
                'load the model with `attn_implementation="eager"` to silence this warning.'
            )
|
return super().forward( |
|
hidden_states=hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
) |
|
bsz, q_len, _ = hidden_states.size() |
|
|
|
query_states = self.q_u_proj(self.q_v_proj(hidden_states)) |
|
key_states = self.k_u_proj(self.k_v_proj(hidden_states)) |
|
value_states = self.v_u_proj(self.v_v_proj(hidden_states)) |
|
|
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
|
|
if position_embeddings is None: |
|
logger.warning_once( |
|
"The attention layers in this model are transitioning from computing the RoPE embeddings internally " |
|
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " |
|
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be " |
|
"removed and `position_embeddings` will be mandatory." |
|
) |
|
cos, sin = self.rotary_emb(value_states, position_ids) |
|
else: |
|
cos, sin = position_embeddings |
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) |
|
|
|
if past_key_value is not None: |
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
|
|
causal_mask = attention_mask |
|
if attention_mask is not None: |
|
causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] |
|
|
|
|
|
|
|
if query_states.device.type == "cuda" and causal_mask is not None: |
|
query_states = query_states.contiguous() |
|
key_states = key_states.contiguous() |
|
value_states = value_states.contiguous() |
|
|
|
|
|
|
|
is_causal = True if causal_mask is None and q_len > 1 else False |
|
|
|
attn_output = torch.nn.functional.scaled_dot_product_attention( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attn_mask=causal_mask, |
|
dropout_p=self.attention_dropout if self.training else 0.0, |
|
is_causal=is_causal, |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.view(bsz, q_len, -1) |
|
|
|
attn_output = self.o_u_proj(self.o_v_proj(attn_output)) |
|
|
|
return attn_output, None, past_key_value |
|
|
|
|
|
class SVDLlamaDecoderLayer(LlamaDecoderLayer): |
|
def __init__(self, config: LlamaConfig, layer_idx: int): |
|
super().__init__(config, layer_idx) |
|
        # Pick the SVD attention variant that matches the configured attention backend;
        # both variants share the same low-rank parameters, so checkpoints stay compatible.
        attn_cls = SVDLLaMASDPA if getattr(config, "_attn_implementation", "eager") == "sdpa" else SVDLlamaAttention
        self.self_attn = attn_cls(config=config, layer_idx=layer_idx)
        self.mlp = SVDLlamaMLP(config)
|
|
|
|
|
class SVDLlamaForCausalLM(LlamaForCausalLM): |
|
def __init__(self, config: LlamaConfig): |
|
super().__init__(config) |
|
        # Rebuild the backbone with SVD-factored decoder layers.
        self.model = LlamaModel(config)
        self.model.layers = nn.ModuleList(
            [SVDLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # Keep each compressed decoder layer on a single device when sharding with `device_map`.
        self.model._no_split_modules = ["SVDLlamaDecoderLayer"]
        self._no_split_modules = ["SVDLlamaDecoderLayer"]
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
|
|
|
|
self.post_init() |
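

# A minimal usage sketch; the base checkpoint name, ratio value, and state-dict path below are
# assumptions for illustration (a real run would load weights produced by an SVD compression
# pipeline whose parameter names match the *_u_proj / *_v_proj modules defined above).
if __name__ == "__main__":
    config = LlamaConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
    config.ratio = 0.2  # keep roughly 20% of each projection's parameters
    model = SVDLlamaForCausalLM(config)  # allocates the full compressed model
    # state_dict = torch.load("svd_llama2_7b_ratio0.2.pt")  # hypothetical checkpoint
    # model.load_state_dict(state_dict)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params / 1e9:.2f}B parameters at ratio={config.ratio}")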