# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.distributed as dist
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb

from xtuner.parallel.sequence import get_sequence_parallel_world_size
from xtuner.parallel.sequence.attention import (
    post_process_for_sequence_parallel_attn,
    pre_process_for_sequence_parallel_attn)

try:
    from transformers.cache_utils import Cache
except ImportError:

    class Cache:
        pass
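

# `cohere_attn_forward` reproduces the flash-attention forward pass of the
# Cohere attention module in `transformers` and adds an optional
# sequence-parallel branch for training; a hedged sketch of how it can be
# wired in is given at the end of this file.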
def cohere_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim)
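    # Keys/values are projected to `num_key_value_heads` heads (grouped-query
    # attention when this is smaller than `num_heads`), and Cohere can
    # optionally LayerNorm the query/key states before RoPE is applied.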
    if self.use_qk_norm:
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin)
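
    # During incremental decoding, `Cache.update` merges the new key/value
    # states into the cache for this layer; in training, `past_key_value` is
    # normally None and the update below is skipped.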
    past_key_value = getattr(self, 'past_key_value', past_key_value)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; position_ids needed for
        # the static cache
        cache_kwargs = {
            'sin': sin,
            'cos': cos,
            'cache_position': cache_position
        }
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # TODO: These transposes are quite inefficient, but Flash Attention
    # requires the layout [batch_size, sequence_length, num_heads, head_dim].
    # We would need to refactor the KV cache to be able to avoid many of
    # these transpose/reshape/view calls.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, the layer norms are usually cast to float32 for training
    # stability, so the input hidden states get silently cast to float32.
    # Cast them back to the expected dtype to make sure everything works as
    # intended. This might slow down training & inference, so it is
    # recommended not to cast the LayerNorms to fp32.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)
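
    # Sequence parallelism is only enabled during training. xtuner's helpers
    # perform an all-to-all exchange (DeepSpeed-Ulysses style) so that each
    # rank attends over the full sequence with a subset of the attention
    # heads; the post-processing step below reverses the exchange.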
    enable_sequence_parallel = (
        dist.is_initialized() and get_sequence_parallel_world_size() > 1
        and self.training)
    if enable_sequence_parallel:
        query_states, key_states, value_states = \
            pre_process_for_sequence_parallel_attn(
                query_states, key_states, value_states)

    attn_output = self._flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        q_len,
        dropout=dropout_rate)

    if enable_sequence_parallel:
        attn_output = post_process_for_sequence_parallel_attn(attn_output)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
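

# A minimal usage sketch, not part of the original module: xtuner normally
# installs this forward through its dispatch utilities, but the effect is
# equivalent to monkey-patching the flash-attention variant of Cohere
# attention. The helper below is illustrative only and assumes a transformers
# release that defines `CohereFlashAttention2` (whose
# `_flash_attention_forward` method this forward relies on).
def _patch_cohere_flash_attn():
    from transformers.models.cohere.modeling_cohere import \
        CohereFlashAttention2

    # Replace the bound forward; every CohereFlashAttention2 layer will then
    # run through `cohere_attn_forward` defined above.
    CohereFlashAttention2.forward = cohere_attn_forward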