zhangtao-whu's picture
Upload folder using huggingface_hub
476ac07 verified
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
import warnings
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from mmengine import MessageHub
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (apply_rotary_pos_emb,
repeat_kv)
from xtuner.parallel.sequence import get_sequence_parallel_world_size
from xtuner.parallel.sequence.attention import (
post_process_for_sequence_parallel_attn,
pre_process_for_sequence_parallel_attn)
from .attention import flash_attn_wo_mask, varlen_flash_attn
from .triton_kernels import apply_rotary_emb
# Feature flags describing the installed flash-attn package.
# ``SUPPORT_FLASH2`` becomes True when ``flash_attn_func`` can be imported.
# ``_flash_supports_window_size`` becomes True when that function accepts a
# ``window_size`` argument (needed for sliding-window attention).
# Both default to False so that code consulting them never hits a NameError
# when flash-attn is missing.
SUPPORT_FLASH2 = False
_flash_supports_window_size = False
try:
    from flash_attn import flash_attn_func
    _flash_supports_window_size = 'window_size' in list(
        inspect.signature(flash_attn_func).parameters)
    SUPPORT_FLASH2 = True
except ImportError:
    pass
class MistralRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) with a lazily grown cos/sin cache.

    ``inv_freq`` is stored as a plain tensor attribute rather than a
    registered buffer. As a result, ``Module.to()`` does not move or cast it,
    and ``_set_cos_sin_cache`` moves it to the requested device explicitly.

    Args:
        dim: Rotary dimension; the cached tables have ``dim`` columns.
        max_position_embeddings: Number of positions pre-computed at
            construction time.
        base: RoPE frequency base.
        device: Device on which ``inv_freq`` and the initial cache are
            created.
    """

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (
            base**(torch.arange(0, self.dim, 2).float().to(device) / self.dim))

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build ``cos_cached``/``sin_cached`` for ``seq_len`` positions.

        The angles are computed in ``inv_freq``'s dtype (float32) and then
        cast to ``dtype``. Both tables have shape ``(seq_len, dim)``.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(device))
        # Different from paper, but it uses a different permutation
        # in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.cos_cached = emb.cos().to(dtype)
        self.sin_cached = emb.sin().to(dtype)

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose ``device`` and ``dtype`` determine those of the
                returned tables. The expected layout is
                ``[bs, num_attention_heads, seq_len, head_size]``. Its shape
                is only consulted when ``seq_len`` is omitted.
            seq_len: Number of positions to return, as an ``int`` or a 0-d
                tensor. Defaults to ``x.shape[-2]``.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(seq_len, dim)`` and dtype
            ``x.dtype``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        # Callers may pass a 0-d tensor (e.g. ``position_ids.max() + 1``).
        # Normalise it so the cache length is always a Python int.
        seq_len = int(seq_len)

        if (seq_len > self.max_seq_len_cached
                or self.cos_cached.device != x.device  # noqa: W503
                or self.cos_cached.dtype != x.dtype):  # noqa: W503
            # Never shrink the cache on a device/dtype-triggered rebuild.
            # Otherwise varying lengths would force repeated rebuilds. The
            # values for a given position do not depend on the cache length.
            self._set_cos_sin_cache(
                seq_len=max(seq_len, self.max_seq_len_cached),
                device=x.device,
                dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Replicate key/value heads so their count matches the query heads.

    Operates on the ``(batch, seqlen, num_key_value_heads, head_dim)`` layout
    and produces ``(batch, seqlen, num_key_value_heads * n_rep, head_dim)``,
    where each kv head appears ``n_rep`` times consecutively.

    Args:
        hidden_states: 4-D key or value tensor.
        n_rep: Number of copies of each kv head.

    Returns:
        ``hidden_states`` itself when ``n_rep == 1``; otherwise a new tensor
        holding the repeated heads.
    """
    bsz, seq, kv_heads, dim = hidden_states.shape
    if n_rep != 1:
        widened = hidden_states.unsqueeze(3).expand(
            bsz, seq, kv_heads, n_rep, dim)
        return widened.reshape(bsz, seq, kv_heads * n_rep, dim)
    return hidden_states
def mistral_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    """Flash-attention forward for Mistral, with optional sequence parallelism.

    Meant to replace the ``forward`` of a Mistral attention module. It relies
    on that module's ``q_proj``/``k_proj``/``v_proj``/``o_proj``,
    ``rotary_emb``, ``config``, ``layer_idx`` and
    ``_flash_attention_forward``.

    Args:
        hidden_states: Input unpacked as ``(bsz, q_len, _)``.
        attention_mask: Forwarded to ``self._flash_attention_forward``. It is
            trimmed when the sliding-window cache path is taken.
        position_ids: Token positions used for RoPE. Required (asserted).
        past_key_value: Optional HF ``Cache``. When given, its ``update`` is
            called with this step's keys/values.
        output_attentions: Ignored. Flash attention never materializes
            attention weights, so ``None`` is always returned in that slot.
        use_cache: Unused here.
        **kwargs: Only the deprecated ``padding_mask`` is consumed.

    Returns:
        ``(attn_output, None, past_key_value)``, where ``attn_output`` has
        shape ``(bsz, q_len, self.hidden_size)``.
    """
    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure use `attention_mask` instead.`')

        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop('padding_mask')

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split heads and move them before the sequence axis:
    # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim).
    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    # Total key length = new tokens + tokens already held in the cache.
    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                'The cache structure has changed since version v4.36. '
                f'If you are using {self.__class__.__name__} '
                'for auto-regressive decoding with k/v caching, '
                'please make sure to initialize the attention class '
                'with a layer index.')
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    assert position_ids is not None
    if self.training:
        # In training the rotary table must cover the largest position id.
        # Presumably that can exceed the local length with packed or
        # sequence-parallel inputs (assumption, verify).
        cos, sin = self.rotary_emb(
            value_states, seq_len=position_ids.max() + 1)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    use_sliding_windows = (
        _flash_supports_window_size
        and getattr(self.config, 'sliding_window', None) is not None
        and kv_seq_len > self.config.sliding_window)

    if past_key_value is not None:
        # Activate slicing cache only if the config has a value
        # `sliding_windows` attribute
        cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
        if (getattr(self.config, 'sliding_window', None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents):
            slicing_tokens = 1 - self.config.sliding_window

            # NOTE(review): the trimmed past_key/past_value below are only
            # validated. They are never written back into the cache (the
            # same behaviour as upstream transformers v4.36).
            past_key = past_key_value[self.layer_idx][0]
            past_value = past_key_value[self.layer_idx][1]

            past_key = past_key[:, :, slicing_tokens:, :].contiguous()
            past_value = past_value[:, :, slicing_tokens:, :].contiguous()

            if past_key.shape[-2] != self.config.sliding_window - 1:
                raise ValueError(
                    'past key must have a shape of (`batch_size, num_heads, '
                    'self.config.sliding_window-1, head_dim`), got'
                    f' {past_key.shape}')

            if attention_mask is not None:
                # Keep the mask aligned with the trimmed window and add one
                # "attend" column for the current token.
                attention_mask = attention_mask[:, slicing_tokens:]
                attention_mask = torch.cat(
                    [attention_mask,
                     torch.ones_like(attention_mask[:, -1:])],
                    dim=-1)

        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # repeat k/v heads if n_kv_heads < n_heads for sequence parallel
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons therefore the input hidden states gets silently
    # casted in float32. Hence, we need cast them back in the correct dtype
    # just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not
    # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # Reshape to the (bsz, seq, heads, head_dim) layout expected by Flash
    # Attention.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    enable_sequence_parallel = (
        dist.is_initialized() and get_sequence_parallel_world_size() > 1
        and self.training)
    if enable_sequence_parallel:
        query_states, key_states, value_states = \
            pre_process_for_sequence_parallel_attn(
                query_states, key_states, value_states)

    # query_length is read after the sequence-parallel pre-processing, so it
    # reflects the tensors actually fed to flash attention.
    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length=query_states.shape[1],
        dropout=dropout_rate,
        use_sliding_windows=use_sliding_windows,
    )

    if enable_sequence_parallel:
        attn_output = post_process_for_sequence_parallel_attn(attn_output)

    attn_output = attn_output.reshape(bsz, q_len,
                                      self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    # Flash attention never materializes attention probabilities. Bind the
    # slot unconditionally so output_attentions=True returns None instead of
    # raising UnboundLocalError.
    attn_weights = None

    return attn_output, attn_weights, past_key_value
def mistral_varlen_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    """Mistral flash-attention forward supporting packed variable-length input.

    Varlen mode is chosen when the ``'varlen_attn_args'`` MessageHub holds a
    non-None ``cumulative_len_rank_{rank}`` for this rank. In that mode RoPE
    is applied with ``apply_rotary_emb`` on tables sized by
    ``max_seqlen_rank_{rank}``, and attention runs through
    ``varlen_flash_attn``. Otherwise standard RoPE (plus the optional k/v
    cache) is used and attention runs through ``flash_attn_wo_mask``.

    The following are enforced by assertions:
        * ``past_key_value is None`` exactly when ``self.training``;
        * batch size 1 and ``attention_mask is None``;
        * the installed flash-attn accepts ``window_size``.

    Args:
        hidden_states: Input unpacked as ``(bsz, q_len, _)``.
        attention_mask: Must be None (asserted).
        position_ids: Token positions used to index the rotary tables.
        past_key_value: Optional HF ``Cache``, used only outside varlen mode.
        output_attentions: Ignored. ``None`` is always returned for weights.
        use_cache: Unused here.
        **kwargs: Only the deprecated ``padding_mask`` is consumed.

    Returns:
        ``(attn_output, None, past_key_value)``, where ``attn_output`` has
        shape ``(bsz, q_len, self.hidden_size)``.
    """
    is_training = self.training

    message_hub = MessageHub.get_instance('varlen_attn_args')
    # torch.distributed.get_rank() fails without a process group. Fall back
    # to rank 0 in that case, matching the guard used by the non-varlen
    # forward.
    rank = dist.get_rank() if dist.is_initialized() else 0
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')

    assert is_training == (past_key_value is None)
    use_varlen_atten = (cumulative_len is not None)

    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in v4.37'
            ' Please make sure use `attention_mask` instead.`')

        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop('padding_mask')

    bsz, q_len, _ = hidden_states.size()
    assert bsz == 1, (f'If utilizing local attention, the batch size should be'
                      f' set to 1, but got {bsz}')
    # attention_mask is set to None if no padding token in input_ids
    assert attention_mask is None

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Keep the (bsz, q_len, heads, head_dim) layout used by flash attention.
    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim)

    assert _flash_supports_window_size, \
        ('The current flash attention version does not support sliding window '
         'attention, for a more memory efficient implementation make sure '
         'to upgrade flash-attn library.')

    # Sequence axis is -3 in the (bsz, q_len, heads, head_dim) layout.
    kv_seq_len = key_states.shape[-3]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                'The cache structure has changed since version v4.36. '
                f'If you are using {self.__class__.__name__} '
                'for auto-regressive decoding with k/v caching, '
                'please make sure to initialize the attention class '
                'with a layer index.')
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    if use_varlen_atten:
        cos, sin = self.rotary_emb(value_states, max_seqlen)
        # Gather the per-token rows of the rotary tables via position_ids.
        query_states = apply_rotary_emb(query_states,
                                        cos[position_ids].squeeze(0),
                                        sin[position_ids].squeeze(0))
        key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
                                      sin[position_ids].squeeze(0))
    else:
        # apply_rotary_pos_emb and the HF cache work on
        # (bsz, heads, seq, head_dim).
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # Because the input can be padded, the absolute sequence length
        # depends on the max position id.
        rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item() + 1)
        cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        # The cache is None whenever training (see the assert above), so it
        # may only be touched when present.
        if past_key_value is not None:
            # Activate slicing cache only if the config has a value
            # `sliding_windows` attribute
            cache_has_contents = past_key_value.get_seq_length(
                self.layer_idx) > 0
            if (getattr(self.config, 'sliding_window', None) is not None
                    and kv_seq_len > self.config.sliding_window  # noqa: W503
                    and cache_has_contents):  # noqa: W503
                slicing_tokens = 1 - self.config.sliding_window

                # NOTE(review): the trimmed past_key/past_value are only
                # validated, never written back into the cache.
                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        'past key must have a shape of (`batch_size, num_heads, '
                        'self.config.sliding_window-1, head_dim`), got'
                        f' {past_key.shape}')

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat(
                        [attention_mask,
                         torch.ones_like(attention_mask[:, -1:])],
                        dim=-1)

            cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs)

        # Back to the (bsz, seq, heads, head_dim) flash-attention layout.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

    # repeat kv for sequence parallel
    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)

    dropout_rate = 0.0 if not self.training else self.attention_dropout

    # In PEFT, usually we cast the layer norms in float32 for
    # training stability reasons, therefore the input hidden states gets
    # silently casted in float32. Hence, we need to cast them back to the
    # autocast / pre-quantization / q_proj weight dtype to be sure
    # everything works as expected.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    # ----------------- flash attention forward ------------------------#

    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        # With a top-left aligned causal mask a single query token would
        # only see the first key, so causality is disabled for q_len == 1.
        causal = self.is_causal and q_len != 1

    use_sliding_windows = (
        _flash_supports_window_size and  # noqa: W504
        getattr(self.config, 'sliding_window', None) is not None  # noqa: W503
        and kv_seq_len > self.config.sliding_window)  # noqa: W503
    # (-1, -1) means "no window" for flash-attn.
    window_size = ((self.config.sliding_window, self.config.sliding_window)
                   if use_sliding_windows else (-1, -1))

    if use_varlen_atten:
        attn_output = varlen_flash_attn(
            query_states,
            key_states,
            value_states,
            cumulative_len,
            max_seqlen,
            causal=causal,
            dropout_p=dropout_rate,
            window_size=window_size,
            training=self.training)
    else:
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=causal,
            dropout_p=dropout_rate,
            window_size=window_size,
            training=self.training)

    # ---------------- flash attention forward end ------------------- #

    attn_output = attn_output.reshape(bsz, q_len,
                                      self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)

    # Flash attention never materializes attention probabilities. Bind the
    # slot unconditionally so output_attentions=True returns None instead of
    # raising UnboundLocalError.
    attn_weights = None

    return attn_output, attn_weights, past_key_value