# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub
from transformers.cache_utils import Cache

from xtuner.model.transformers_models.deepseek_v2.modeling_deepseek import \
    apply_rotary_pos_emb
from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      post_process_for_sequence_parallel_attn,
                                      pre_process_for_sequence_parallel_attn)

from .attention import flash_attn_wo_mask, varlen_flash_attn
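

# Replacement flash-attention forwards for DeepseekV2 MLA (multi-head latent
# attention). `deepseek_attn_forward` mirrors the upstream
# DeepseekV2FlashAttention2.forward and adds sequence-parallel support;
# `deepseek_varlen_attn_forward` additionally handles packed (variable-length)
# training batches.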
def deepseek_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    # DeepseekV2FlashAttention2 attention does not support output_attentions
    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure to use `attention_mask` instead.')
        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop('padding_mask')

    output_attentions = False

    bsz, q_len, _ = hidden_states.size()
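
    # MLA query path: either a single projection or a low-rank two-step
    # projection (q_a_proj -> RMSNorm -> q_b_proj), then split each head into
    # a content part (qk_nope_head_dim) and a rotary part (qk_rope_head_dim).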
    if self.q_lora_rank is None:
        q = self.q_proj(hidden_states)
    else:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    q_nope, q_pe = torch.split(
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # Flash attention expects inputs of shape
    # (batch_size, seq_length, num_heads, head_dim); q/k/v are built
    # head-major here and transposed back to that layout right before the
    # attention call.
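    # MLA key/value path: compress hidden states into a small latent
    # (kv_lora_rank) plus a single shared rotary key, then expand the latent
    # into per-head content keys and values.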
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    kv = (
        self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
    k_nope, value_states = torch.split(
        kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    kv_seq_len = value_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    assert position_ids is not None, '`position_ids` should not be None.'
    if self.training:
        cos, sin = self.rotary_emb(
            value_states, seq_len=position_ids.max() + 1)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)
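
    # Re-assemble full-width queries and keys: the first qk_nope_head_dim
    # channels carry the content part, the remaining qk_rope_head_dim channels
    # carry the rotary part (k_pe is broadcast across all heads).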
    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
    query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

    key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
    key_states[:, :, :, self.qk_nope_head_dim:] = k_pe
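
    # Flash attention requires q, k and v to share the same head dim, so pad
    # values from v_head_dim up to q_head_dim; the padding is sliced off again
    # after the attention call.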
    if self.q_head_dim != self.v_head_dim:
        value_states = F.pad(value_states,
                             [0, self.q_head_dim - self.v_head_dim])

    if past_key_value is not None:
        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)
    # Reshape to the expected shape for flash attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, we usually cast the layer norms to float32 for training
    # stability, so the input hidden states get silently cast to float32.
    # Hence, we need to cast them back to the correct dtype to make sure
    # everything works as expected.
    # This might slow down training & inference, so it is recommended not to
    # cast the LayerNorms to fp32. (DeepseekV2RMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        # Handle the case where the model is quantized
        if hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        elif torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = self.q_a_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
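
    # With sequence parallelism, each rank holds a slice of the sequence; the
    # pre/post processing helpers redistribute the shards so that attention
    # sees the full sequence but only a subset of heads, then restore the
    # original layout afterwards.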
    enable_sequence_parallel = (
        dist.is_initialized() and get_sequence_parallel_world_size() > 1
        and self.training)
    if enable_sequence_parallel:
        query_states, key_states, value_states = \
            pre_process_for_sequence_parallel_attn(
                query_states, key_states, value_states)

    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_states.shape[1],
        dropout=dropout_rate,
        softmax_scale=self.softmax_scale,
    )

    if enable_sequence_parallel:
        attn_output = post_process_for_sequence_parallel_attn(attn_output)

    if self.q_head_dim != self.v_head_dim:
        attn_output = attn_output[:, :, :, :self.v_head_dim]

    attn_output = attn_output.reshape(bsz, q_len, self.num_heads *
                                      self.v_head_dim).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value
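

# Variable-length ("varlen") variant: multiple samples are packed into one
# sequence per rank, and the per-sample boundaries (`cumulative_len`,
# `max_seqlen`) are read from the MessageHub that xtuner's data pipeline fills
# in before each forward pass.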
def deepseek_varlen_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    is_training = self.training

    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
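
    # Packed-sequence metadata is only available during training; at inference
    # time cumulative_len is None and a KV cache is used instead, which is
    # exactly what the assertion below checks.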
    assert is_training == (cumulative_len is not None) == (
        past_key_value is None)

    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    if self.q_lora_rank is None:
        q = self.q_proj(hidden_states)
    else:
        q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
    q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2)
    q_nope, q_pe = torch.split(
        q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
    # Flash attention expects inputs of shape
    # (batch_size, seq_length, num_heads, head_dim); q/k/v are built
    # head-major here and transposed back to that layout right before the
    # attention call.
    compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
    compressed_kv, k_pe = torch.split(
        compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
    k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2)
    kv = (
        self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).view(
            bsz, q_len, self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim).transpose(1, 2))
    k_nope, value_states = torch.split(
        kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    kv_seq_len = value_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    assert position_ids is not None, '`position_ids` should not be None.'
    if self.training:
        cos, sin = self.rotary_emb(
            value_states, seq_len=position_ids.max() + 1)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids)

    query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    query_states[:, :, :, :self.qk_nope_head_dim] = q_nope
    query_states[:, :, :, self.qk_nope_head_dim:] = q_pe

    key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim)
    key_states[:, :, :, :self.qk_nope_head_dim] = k_nope
    key_states[:, :, :, self.qk_nope_head_dim:] = k_pe

    if self.q_head_dim != self.v_head_dim:
        value_states = F.pad(value_states,
                             [0, self.q_head_dim - self.v_head_dim])

    if past_key_value is not None:
        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # In PEFT, we usually cast the layer norms to float32 for training
    # stability, so the input hidden states get silently cast to float32.
    # Hence, we need to cast them back to the correct dtype to make sure
    # everything works as expected.
    # This might slow down training & inference, so it is recommended not to
    # cast the LayerNorms to fp32. (DeepseekV2RMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        # Handle the case where the model is quantized
        if hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        elif torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        else:
            target_dtype = self.q_a_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
    # Reshape to the expected shape for flash attention
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # ----------------- varlen flash attention forward ----------------------#
    dropout_rate = self.attention_dropout if self.training else 0.0

    if not self._flash_attn_uses_top_left_mask:
        causal = self.is_causal
    else:
        causal = self.is_causal and q_len != 1
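
    # Training uses the packed (varlen) flash-attention kernel, which needs
    # the cumulative sequence lengths to keep packed samples from attending to
    # each other; inference falls back to plain causal flash attention without
    # a mask.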
    if is_training:
        attn_output = varlen_flash_attn(
            query_states,
            key_states,
            value_states,
            cumulative_len,
            max_seqlen,
            causal=causal,
            dropout_p=dropout_rate,
            training=True)
    else:
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=causal,
            dropout_p=dropout_rate,
            training=False)

    # ---------------- varlen flash attention forward end ------------------ #

    if self.q_head_dim != self.v_head_dim:
        attn_output = attn_output[:, :, :, :self.v_head_dim]

    attn_output = attn_output.reshape(bsz, q_len,
                                      self.num_heads * self.v_head_dim)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None
    return attn_output, attn_weights, past_key_value
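

# Illustrative only: in xtuner these forwards are normally wired in by its
# dispatch utilities, but conceptually they replace the attention module's
# forward via monkey-patching, e.g. (names below are assumed, not part of this
# file):
#
#   import types
#   # attn_module would be a DeepseekV2FlashAttention2 instance
#   attn_module.forward = types.MethodType(deepseek_varlen_attn_forward,
#                                          attn_module)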