import math
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaModel,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
    apply_rotary_pos_emb,
    repeat_kv,
)
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import logging

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"

ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)


class SVDLlamaMLP(nn.Module):
    """Llama MLP in which each dense projection is replaced by a low-rank factorization `u_proj(v_proj(x))`."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.ratio = config.ratio
        # Rank chosen so that the two factors together hold roughly `ratio` times the parameters
        # of the original (intermediate_size x hidden_size) weight matrix.
        self.low_rank = int(
            self.intermediate_size * self.hidden_size * self.ratio / (self.intermediate_size + self.hidden_size)
        )

        self.gate_u_proj = nn.Linear(self.low_rank, self.intermediate_size, bias=config.mlp_bias)
        self.gate_v_proj = nn.Linear(self.hidden_size, self.low_rank, bias=False)

        self.down_u_proj = nn.Linear(self.low_rank, self.hidden_size, bias=config.mlp_bias)
        self.down_v_proj = nn.Linear(self.intermediate_size, self.low_rank, bias=False)

        self.up_u_proj = nn.Linear(self.low_rank, self.intermediate_size, bias=config.mlp_bias)
        self.up_v_proj = nn.Linear(self.hidden_size, self.low_rank, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        up = self.up_u_proj(self.up_v_proj(x))
        gate = self.gate_u_proj(self.gate_v_proj(x))
        return self.down_u_proj(self.down_v_proj(self.act_fn(gate) * up))


class SVDLlamaAttention(nn.Module):
    """Multi-headed attention from the 'Attention Is All You Need' paper, with low-rank (SVD) projections."""

    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.ratio = config.ratio  # ratio == 1 means no truncation, i.e. keep the full-rank attention

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_low_rank = int(
            self.num_heads * self.head_dim * self.hidden_size * self.ratio
            / (self.num_heads * self.head_dim + self.hidden_size)
        )
        self.q_u_proj = nn.Linear(self.q_low_rank, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.q_v_proj = nn.Linear(self.hidden_size, self.q_low_rank, bias=False)

        self.k_low_rank = int(
            self.num_key_value_heads * self.head_dim * self.hidden_size * self.ratio
            / (self.num_key_value_heads * self.head_dim + self.hidden_size)
        )
        self.k_u_proj = nn.Linear(self.k_low_rank, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.k_v_proj = nn.Linear(self.hidden_size, self.k_low_rank, bias=False)

        self.v_low_rank = int(
            self.num_key_value_heads * self.head_dim * self.hidden_size * self.ratio
            / (self.num_key_value_heads * self.head_dim + self.hidden_size)
        )
        self.v_u_proj = nn.Linear(self.v_low_rank, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_v_proj = nn.Linear(self.hidden_size, self.v_low_rank, bias=False)

        self.o_low_rank = int(
            self.hidden_size * self.hidden_size * self.ratio / (self.hidden_size + self.hidden_size)
        )
        self.o_u_proj = nn.Linear(self.o_low_rank, self.hidden_size, bias=config.attention_bias)
        self.o_v_proj = nn.Linear(self.hidden_size, self.o_low_rank, bias=False)

        # TODO (joao): remove in v4.45 (RoPE is computed in the model, not in the decoder layers)
        self.rotary_emb = LlamaRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_u_proj(self.q_v_proj(hidden_states))
        key_states = self.k_u_proj(self.k_v_proj(hidden_states))
        value_states = self.v_u_proj(self.v_v_proj(hidden_states))

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)

        attn_output = self.o_u_proj(self.o_v_proj(attn_output))

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class SVDLLaMASDPA(SVDLlamaAttention):
    """SDPA variant of `SVDLlamaAttention`; falls back to the eager parent implementation when attention weights are requested."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_u_proj(self.q_v_proj(hidden_states))
        key_states = self.k_u_proj(self.k_v_proj(hidden_states))
        value_states = self.v_u_proj(self.v_v_proj(hidden_states))

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.45 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, -1)

        attn_output = self.o_u_proj(self.o_v_proj(attn_output))

        return attn_output, None, past_key_value


class SVDLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        # Replace the dense attention and MLP built by the parent class with their low-rank counterparts.
        # The SDPA variant is used; it falls back to the eager `SVDLlamaAttention` when attention weights
        # are requested via `output_attentions=True`.
        self.self_attn = SVDLLaMASDPA(config=config, layer_idx=layer_idx)
        self.mlp = SVDLlamaMLP(config)


class SVDLlamaForCausalLM(LlamaForCausalLM):
    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.model = LlamaModel(config)
        # Swap every decoder layer for its SVD (low-rank) counterpart.
        self.model.layers = nn.ModuleList(
            [SVDLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.model._no_split_modules = ["SVDLlamaDecoderLayer"]
        self._no_split_modules = ["SVDLlamaDecoderLayer"]
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
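

if __name__ == "__main__":
    # Minimal usage sketch: build a tiny `LlamaConfig`, attach the extra `ratio` attribute this module
    # expects, and run a forward pass to check shapes. All hyperparameter values below are illustrative
    # placeholders, not taken from any released checkpoint, and the model here is randomly initialized;
    # in practice the `*_u_proj` / `*_v_proj` weights would be filled with factors obtained from an SVD
    # of a pretrained model.
    config = LlamaConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=688,
        num_hidden_layers=2,
        num_attention_heads=8,
        num_key_value_heads=8,
        max_position_embeddings=512,
    )
    config.ratio = 0.5  # keep roughly half of the original parameter budget in each factorized layer

    model = SVDLlamaForCausalLM(config)
    model.eval()

    dummy_input = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        out = model(dummy_input)
    print(out.logits.shape)  # expected: torch.Size([1, 16, config.vocab_size])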