Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import Optional, Tuple | |
import torch | |
import torch.distributed as dist | |
from mmengine import MessageHub | |
from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb, | |
repeat_kv) | |
from transformers.utils import is_flash_attn_greater_or_equal_2_10 | |
from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask, | |
varlen_flash_attn) | |
from .triton_kernels import apply_rotary_emb | |
try: | |
from transformers.cache_utils import Cache | |
except ImportError: | |
class Cache: | |
pass | |
def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    The input is laid out as (batch, seqlen, num_key_value_heads, head_dim)
    and the output as (batch, seqlen, num_key_value_heads * n_rep, head_dim),
    with the copies of each original head placed next to one another.

    Args:
        hidden_states: 4-D key or value tensor in the layout described above.
        n_rep: Number of copies to make of each head.

    Returns:
        ``hidden_states`` itself when ``n_rep == 1``; otherwise a new tensor
        with the head axis enlarged by a factor of ``n_rep``.
    """
    # Unpacking first also enforces that the input is exactly 4-D.
    bsz, seq_len, kv_heads, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a singleton "copy" axis after the head axis, broadcast it to
    # n_rep, then merge (kv_heads, n_rep) into a single head axis.
    tiled = hidden_states.unsqueeze(3).expand(bsz, seq_len, kv_heads, n_rep,
                                              dim)
    return tiled.reshape(bsz, seq_len, kv_heads * n_rep, dim)
def llama_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Flash-attention replacement for a Llama attention module's forward.

    Args:
        hidden_states: 3-D input; its first two sizes are used as
            ``(bsz, q_len)``.
        attention_mask: ``None`` or a 2-D mask. Any other rank trips the
            assertion below.
        position_ids: Passed to ``self.rotary_emb`` to build cos/sin.
        past_key_value: Optional cache. A ``past_key_value`` attribute on
            ``self`` takes precedence over this argument.
        output_attentions: Ignored; attention weights are never returned.
        use_cache: Unused by this implementation.
        cache_position: Forwarded to the cache's ``update`` call.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, None, past_key_value)``, where ``attn_output`` is the
        output of ``self.o_proj`` applied to a
        ``(bsz, q_len, self.hidden_size)`` tensor.

    Raises:
        AssertionError: If ``SUPPORT_FLASH2`` is falsy or ``attention_mask``
            is given but is not 2-D.
    """
    # Modified from https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/llama/modeling_llama.py#L422  # noqa:E501
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # Split the projections into heads and move the head axis forward:
    # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim).
    # This layout is used for the rotary embedding and the cache update; it
    # is transposed back before the flash-attn wrappers are called.
    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin)

    # A cache attached to the module wins over the one passed by the caller.
    past_key_value = getattr(self, 'past_key_value', past_key_value)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models;
        # cache_position needed for the static cache
        cache_kwargs = {
            'sin': sin,
            'cos': cos,
            'cache_position': cache_position
        }
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    assert SUPPORT_FLASH2
    # Back to (bsz, seq_len, heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons therefore the input hidden states gets silently
    # casted in float32. Hence, we need cast them back in the correct dtype
    # just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not
    # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    dropout_rate = self.attention_dropout if self.training else 0.0

    if is_flash_attn_greater_or_equal_2_10():
        causal = self.is_causal
    else:
        # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
        # is bumped to 2.1. For details, please see the comment in
        # LlamaFlashAttention2 __init__.
        causal = self.is_causal and q_len != 1

    # the shape of attention_mask used by flash_attn and
    # F.scaled_dot_product_attention are different
    assert attention_mask is None or attention_mask.ndim == 2, \
        ('When using flash_attn, attention_mask.ndim should equal to 2. '
         f'But got attention_mask.shape = {attention_mask.shape}. '
         'We can pass the `attn_implementation="flash_attention_2"` flag '
         'to `.from_pretrained` method when instantiating a Llama '
         'model.')

    if attention_mask is not None:
        attn_output = flash_attn_w_mask(
            query_states,
            key_states,
            value_states,
            attention_mask,
            causal=causal,
            dropout_p=dropout_rate,
            training=self.training)
    else:
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=causal,
            dropout_p=dropout_rate,
            training=self.training)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Attention weights are never produced on this path, so the second
    # element of the result is always None regardless of output_attentions.
    return attn_output, None, past_key_value
def llama_attn_forward_legacy(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    """Flash-attention forward for Llama attention (older rotary/cache API).

    Unlike ``llama_attn_forward``, this variant calls ``self.rotary_emb`` with
    a ``seq_len`` argument and applies ``position_ids`` inside
    ``apply_rotary_pos_emb``.

    Args:
        hidden_states: 3-D input; its first two sizes are used as
            ``(bsz, q_len)``.
        attention_mask: ``None`` or a 2-D mask. Any other rank trips the
            assertion below.
        position_ids: Required. During training the rotary table length is
            ``position_ids.max() + 1``.
        past_key_value: Optional cache, updated with the new key/value
            states.
        output_attentions: Ignored; attention weights are never returned.
        use_cache: Unused by this implementation.
        **kwargs: Only inspected for the deprecated ``padding_mask`` key.

    Returns:
        ``(attn_output, None, past_key_value)``.

    Raises:
        ValueError: If a cache is given but ``self.layer_idx`` is ``None``.
        AssertionError: If ``position_ids`` is ``None``, ``SUPPORT_FLASH2``
            is falsy, or ``attention_mask`` is given but is not 2-D.
    """
    # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331  # noqa:E501
    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure use `attention_mask` instead.`')

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                'The cache structure has changed since version v4.36. '
                f'If you are using {self.__class__.__name__} '
                'for auto-regressive decoding with k/v caching, '
                'please make sure to initialize the attention class '
                'with a layer index.')
        # Extend the length by whatever the cache reports as usable for
        # this layer.
        kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                       self.layer_idx)

    assert position_ids is not None
    if self.training:
        # Size the rotary table from the largest position id rather than
        # from kv_seq_len.
        cos, sin = self.rotary_emb(
            value_states, seq_len=position_ids.max() + 1)
    else:
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    if past_key_value is not None:
        cache_kwargs = {'sin': sin, 'cos': cos}  # Specific to RoPE models
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    assert SUPPORT_FLASH2
    # Back to (bsz, seq_len, heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons therefore the input hidden states gets silently
    # casted in float32. Hence, we need cast them back in the correct dtype
    # just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not
    # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    dropout_rate = self.attention_dropout if self.training else 0.0

    if is_flash_attn_greater_or_equal_2_10():
        causal = self.is_causal
    else:
        # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
        # is bumped to 2.1. For details, please see the comment in
        # LlamaFlashAttention2 __init__.
        causal = self.is_causal and q_len != 1

    # the shape of attention_mask used by flash_attn and
    # F.scaled_dot_product_attention are different
    assert attention_mask is None or attention_mask.ndim == 2, \
        ('When using flash_attn, attention_mask.ndim should equal to 2. '
         f'But got attention_mask.shape = {attention_mask.shape}. '
         'We can pass the `attn_implementation="flash_attention_2"` flag '
         'to `.from_pretrained` method when instantiating a Llama '
         'model.')

    if attention_mask is not None:
        attn_output = flash_attn_w_mask(
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            causal=causal,
            dropout_p=dropout_rate,
            training=self.training)
    else:
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=causal,
            dropout_p=dropout_rate,
            training=self.training)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Due to the implementation of the PyTorch version of flash attention,
    # even when the output_attentions flag is set to True, it is not possible
    # to return the attn_weights.
    return attn_output, None, past_key_value
def llama_varlen_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    """Variable-length flash-attention forward for a Llama attention module.

    ``cumulative_len`` and ``max_seqlen`` are read from the
    ``'varlen_attn_args'`` ``MessageHub`` instance under keys suffixed with
    the current distributed rank. When ``cumulative_len`` is present the
    varlen kernel is used; otherwise plain causal flash attention is used.
    Attention is always causal here and ``attention_mask`` is not consulted.

    Args:
        hidden_states: 3-D input; its first two sizes are used as
            ``(bsz, q_len)``.
        attention_mask: Unused by this implementation.
        position_ids: Passed to ``self.rotary_emb`` to build cos/sin.
        past_key_value: Optional cache. A ``past_key_value`` attribute on
            ``self`` takes precedence over this argument.
        output_attentions: Ignored; attention weights are never returned.
        use_cache: Unused by this implementation.
        cache_position: Forwarded to the cache's ``update`` call.
        **kwargs: Only inspected for the deprecated ``padding_mask`` key.

    Returns:
        ``(attn_output, None, past_key_value)``.

    Raises:
        AssertionError: If ``SUPPORT_FLASH2`` is falsy.
    """
    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
    use_varlen_atten = (cumulative_len is not None)

    if 'padding_mask' in kwargs:
        warnings.warn('Passing `padding_mask` is deprecated and will be '
                      'removed in v4.37. Please make sure use '
                      '`attention_mask` instead.`')
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin)

    # A cache attached to the module wins over the one passed by the caller.
    past_key_value = getattr(self, 'past_key_value', past_key_value)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models;
        # cache_position needed for the static cache
        cache_kwargs = {
            'sin': sin,
            'cos': cos,
            'cache_position': cache_position
        }
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # Back to (bsz, seq_len, heads, head_dim).
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    # repeat kv for sequence parallel
    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons therefore the input hidden states gets silently casted
    # in float32. Hence, we need cast them back in the correct dtype
    # just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not
    # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    assert SUPPORT_FLASH2
    if use_varlen_atten:
        attn_output = varlen_flash_attn(
            query_states,
            key_states,
            value_states,
            cumulative_len,
            max_seqlen,
            causal=True,
            dropout_p=dropout_rate,
            training=self.training)
    else:
        # Pass the dropout rate here as well, so both branches apply the
        # configured attention dropout during training.
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=True,
            dropout_p=dropout_rate,
            training=self.training)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Attention weights are never produced on this path.
    return attn_output, None, past_key_value
def llama_varlen_attn_forward_legacy(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    """Variable-length flash-attention forward (older rotary/cache API).

    The per-rank ``cumulative_len`` / ``max_seqlen`` entries are fetched from
    the ``'varlen_attn_args'`` ``MessageHub`` instance. If ``cumulative_len``
    is set, rotary embeddings are applied with ``apply_rotary_emb`` and the
    varlen kernel is used. Otherwise the function takes the
    ``apply_rotary_pos_emb`` + cache-update route and calls plain causal
    flash attention. ``attention_mask`` is not consulted.

    Returns:
        ``(attn_output, None, past_key_value)``.

    Raises:
        ValueError: If a cache is given but ``self.layer_idx`` is ``None``.
        AssertionError: If ``SUPPORT_FLASH2`` is falsy.
    """
    hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = hub.get_info(f'max_seqlen_rank_{rank}')
    packed = cumulative_len is not None

    if 'padding_mask' in kwargs:
        warnings.warn('Passing `padding_mask` is deprecated and will be '
                      'removed in v4.37. Please make sure use '
                      '`attention_mask` instead.`')

    batch_size, seq_len, _ = hidden_states.size()
    q_shape = (batch_size, seq_len, self.num_heads, self.head_dim)
    kv_shape = (batch_size, seq_len, self.num_key_value_heads, self.head_dim)

    # Projections are kept in (batch, seq, heads, head_dim) layout here; the
    # head axis is only moved forward on the non-varlen path below.
    query_states = self.q_proj(hidden_states).view(*q_shape)
    key_states = self.k_proj(hidden_states).view(*kv_shape)
    value_states = self.v_proj(hidden_states).view(*kv_shape)

    kv_seq_len = key_states.size(1)
    if past_key_value is not None:
        if self.layer_idx is None:
            raise ValueError(
                'The cache structure has changed since version v4.36. '
                f'If you are using {self.__class__.__name__} '
                'for auto-regressive decoding with k/v caching, '
                'please make sure to initialize the attention class '
                'with a layer index.')
        kv_seq_len = kv_seq_len + past_key_value.get_usable_length(
            kv_seq_len, self.layer_idx)

    if packed:
        cos, sin = self.rotary_emb(value_states, max_seqlen)
        # position_ids (1, seq_len)
        # cos, sin (1, seq_len, dim) -> (seq_len, dim)
        cos = cos[position_ids].squeeze(0)
        sin = sin[position_ids].squeeze(0)
        query_states = apply_rotary_emb(query_states, cos, sin)
        key_states = apply_rotary_emb(key_states, cos, sin)
    else:
        # Move heads ahead of the sequence axis for rotary + cache update.
        query_states, key_states, value_states = (
            t.transpose(1, 2)
            for t in (query_states, key_states, value_states))
        cos, sin = self.rotary_emb(value_states, kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # Specific to RoPE models
            rope_kwargs = {'sin': sin, 'cos': cos}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, rope_kwargs)

        # ...and back to (batch, seq, heads, head_dim).
        query_states, key_states, value_states = (
            t.transpose(1, 2)
            for t in (query_states, key_states, value_states))

    # repeat kv for sequence parallel
    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)

    # In PEFT, usually we cast the layer norms in float32 for training
    # stability reasons therefore the input hidden states gets silently casted
    # in float32. Hence, we need cast them back in the correct dtype
    # just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not
    # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
    if query_states.dtype == torch.float32:
        if torch.is_autocast_enabled():
            cast_to = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, '_pre_quantization_dtype'):
            # Handle the case where the model is quantized
            cast_to = self.config._pre_quantization_dtype
        else:
            cast_to = self.q_proj.weight.dtype
        query_states, key_states, value_states = (
            t.to(cast_to) for t in (query_states, key_states, value_states))

    assert SUPPORT_FLASH2
    # Keyword arguments shared by both attention kernels.
    shared_kwargs = dict(
        causal=True,
        dropout_p=self.attention_dropout if self.training else 0.0,
        training=self.training)
    if packed:
        attn_output = varlen_flash_attn(query_states, key_states,
                                        value_states, cumulative_len,
                                        max_seqlen, **shared_kwargs)
    else:
        attn_output = flash_attn_wo_mask(query_states, key_states,
                                         value_states, **shared_kwargs)

    attn_output = self.o_proj(
        attn_output.reshape(batch_size, seq_len, self.hidden_size))

    # Due to the implementation of the PyTorch version of flash attention,
    # even when the output_attentions flag is set to True, it is not possible
    # to return the attn_weights.
    return attn_output, None, past_key_value