# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import warnings
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine import MessageHub
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
    apply_rotary_pos_emb, repeat_kv)

from xtuner.parallel.sequence import get_sequence_parallel_world_size
from xtuner.parallel.sequence.attention import (
    post_process_for_sequence_parallel_attn,
    pre_process_for_sequence_parallel_attn)

from .attention import flash_attn_wo_mask, varlen_flash_attn
from .triton_kernels import apply_rotary_emb

SUPPORT_FLASH2 = False

try:
    from flash_attn import flash_attn_func

    _flash_supports_window_size = 'window_size' in list(
        inspect.signature(flash_attn_func).parameters)
    SUPPORT_FLASH2 = True
except ImportError:
    pass


class MistralRotaryEmbedding(nn.Module):

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (
            base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(device))
        # Different from paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if (seq_len > self.max_seq_len_cached
                or self.cos_cached.device != x.device  # noqa: W503
                or self.cos_cached.dtype != x.dtype):  # noqa: W503
            self._set_cos_sin_cache(
                seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """The hidden states go from (batch, seqlen, num_key_value_heads,
    head_dim) to (batch, seqlen, num_attention_heads, head_dim)."""
    batch, slen, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, None, :].expand(
        batch, slen, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
                                 head_dim)
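

# Illustrative shape check for `repeat_kv_bshd` (sizes are hypothetical, not
# tied to any particular Mistral config): with 2 KV heads repeated 4 times,
#
#   kv = torch.randn(1, 16, 2, 64)  # (batch, seqlen, kv_heads, head_dim)
#   assert repeat_kv_bshd(kv, 4).shape == (1, 16, 8, 64)
#
# i.e. grouped-query K/V are expanded to match the query head count in the
# (batch, seqlen, heads, head_dim) layout that flash attention expects.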


def mistral_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure to use `attention_mask` instead.')

        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop('padding_mask')
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                'The cache structure has changed since version v4.36. '
                f'If you are using {self.__class__.__name__} '
                'for auto-regressive decoding with k/v caching, '
                'please make sure to initialize the attention class '
                'with a layer index.')
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    assert position_ids is not None
    if self.training:
        cos, sin = self.rotary_emb(
            value_states, seq_len=position_ids.max() + 1)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    use_sliding_windows = (
        _flash_supports_window_size
        and getattr(self.config, 'sliding_window', None) is not None
        and kv_seq_len > self.config.sliding_window)

    if past_key_value is not None:
        # Activate slicing cache only if the config has a `sliding_window`
        # attribute
        cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
        if (getattr(self.config, 'sliding_window', None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents):
            slicing_tokens = 1 - self.config.sliding_window

            past_key = past_key_value[self.layer_idx][0]
            past_value = past_key_value[self.layer_idx][1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != self.config.sliding_window - 1:
                raise ValueError(
                    'past key must have a shape of (`batch_size, num_heads, '
                    'self.config.sliding_window-1, head_dim`), got'
                    f' {past_key.shape}')

            if attention_mask is not None:
                attention_mask = attention_mask[:, slicing_tokens:]
                attention_mask = torch.cat(
                    [attention_mask,
                     torch.ones_like(attention_mask[:, -1:])],
                    dim=-1)

        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
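
    # The k/v repetition above is done eagerly (rather than inside the
    # attention kernel) presumably so that q, k and v all carry `num_heads`
    # heads before the sequence-parallel all-to-all below, which
    # redistributes tensors along the head dimension across ranks.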

    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons, therefore the input hidden states get silently cast
    # to float32. Hence, we need to cast them back to the correct dtype just
    # to be sure everything works as expected. This might slow down training
    # & inference, so it is recommended not to cast the LayerNorms in fp32.
    # (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the expected shape for Flash Attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    enable_sequence_parallel = (
        dist.is_initialized() and get_sequence_parallel_world_size() > 1
        and self.training)
    if enable_sequence_parallel:
        query_states, key_states, value_states = \
            pre_process_for_sequence_parallel_attn(
                query_states, key_states, value_states)

    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length=query_states.shape[1],
        dropout=dropout_rate,
        use_sliding_windows=use_sliding_windows,
    )

    if enable_sequence_parallel:
        attn_output = post_process_for_sequence_parallel_attn(attn_output)

    attn_output = attn_output.reshape(bsz, q_len,
                                      self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
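

# `mistral_varlen_attn_forward` supports "packed" training: several samples
# are concatenated into a single sequence with batch size 1, and the per-rank
# `cumulative_len` / `max_seqlen` values recorded in the `varlen_attn_args`
# MessageHub mark where each sub-sequence starts, so attention is never
# computed across sample boundaries.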
def mistral_varlen_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    is_training = self.training

    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')

    assert is_training == (past_key_value is None)
    use_varlen_atten = (cumulative_len is not None)

    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure to use `attention_mask` instead.')

        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop('padding_mask')
    bsz, q_len, _ = hidden_states.size()

    assert bsz == 1, (f'If utilizing local attention, the batch size should '
                      f'be set to 1, but got {bsz}')
    # attention_mask is set to None if no padding token in input_ids
    assert attention_mask is None

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim)

    assert _flash_supports_window_size, \
        ('The current flash attention version does not support sliding '
         'window attention. For a more memory-efficient implementation, '
         'make sure to upgrade the flash-attn library.')

    kv_seq_len = key_states.shape[-3]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                'The cache structure has changed since version v4.36. '
                f'If you are using {self.__class__.__name__} '
                'for auto-regressive decoding with k/v caching, '
                'please make sure to initialize the attention class '
                'with a layer index.')
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    if use_varlen_atten:
        cos, sin = self.rotary_emb(value_states, max_seqlen)
        query_states = apply_rotary_emb(query_states,
                                        cos[position_ids].squeeze(0),
                                        sin[position_ids].squeeze(0))
        key_states = apply_rotary_emb(key_states,
                                      cos[position_ids].squeeze(0),
                                      sin[position_ids].squeeze(0))
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # Because the input can be padded, the absolute sequence length
        # depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        # Activate slicing cache only if the config has a `sliding_window`
        # attribute
        cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
        if (getattr(self.config, 'sliding_window', None) is not None
                and kv_seq_len > self.config.sliding_window  # noqa: W503
                and cache_has_contents):  # noqa: W503
            slicing_tokens = 1 - self.config.sliding_window

            past_key = past_key_value[self.layer_idx][0]
            past_value = past_key_value[self.layer_idx][1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != self.config.sliding_window - 1:
                raise ValueError(
                    'past key must have a shape of (`batch_size, num_heads, '
                    'self.config.sliding_window-1, head_dim`), got'
                    f' {past_key.shape}')

            if attention_mask is not None:
                attention_mask = attention_mask[:, slicing_tokens:]
                attention_mask = torch.cat(
                    [attention_mask,
                     torch.ones_like(attention_mask[:, -1:])],
                    dim=-1)

        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

    # repeat kv for sequence parallel
    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
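
    # Unlike the non-varlen path, q/k/v here stay in the (batch, seqlen,
    # heads, head_dim) layout used by the flash-attn varlen kernels, hence
    # repeat_kv_bshd instead of transformers' repeat_kv.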

    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons, therefore the input hidden states get silently cast
    # to float32. Hence, we need to cast them back to the correct dtype just
    # to be sure everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # ----------------- flash attention forward ------------------------#

    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and q_len != 1

    use_sliding_windows = (
        _flash_supports_window_size
        and getattr(self.config, 'sliding_window', None) is not None
        and kv_seq_len > self.config.sliding_window)
    if use_sliding_windows:
        window_size = (self.config.sliding_window, self.config.sliding_window)
    else:
        window_size = (-1, -1)

    if use_varlen_atten:
        attn_output = varlen_flash_attn(
            query_states,
            key_states,
            value_states,
            cumulative_len,
            max_seqlen,
            causal=causal,
            dropout_p=dropout_rate,
            window_size=window_size,
            training=self.training)
    else:
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=causal,
            dropout_p=dropout_rate,
            window_size=window_size,
            training=self.training)

    # ---------------- flash attention forward end ------------------- #

    attn_output = attn_output.reshape(bsz, q_len,
                                      self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
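

# A minimal usage sketch, not part of this module: these forwards are meant
# to be monkey-patched onto the Hugging Face flash-attention class (the class
# name below is an assumption about the transformers version in use), roughly
#
#   from transformers.models.mistral import modeling_mistral
#   modeling_mistral.MistralFlashAttention2.forward = mistral_attn_forward
#   # or, when training on packed/varlen samples:
#   modeling_mistral.MistralFlashAttention2.forward = \
#       mistral_varlen_attn_forward
#
# so that every attention layer of a loaded Mistral model picks up the
# flash-attention, sequence-parallel and varlen paths implemented above.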