Spaces:
Runtime error
Runtime error
File size: 4,619 Bytes
476ac07 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
import torch.distributed as dist
from transformers.models.cohere.modeling_cohere import apply_rotary_pos_emb
from xtuner.parallel.sequence import get_sequence_parallel_world_size
from xtuner.parallel.sequence.attention import (
post_process_for_sequence_parallel_attn,
pre_process_for_sequence_parallel_attn)
try:
    from transformers.cache_utils import Cache
except ImportError:
    # ``transformers.cache_utils`` may be missing on some transformers
    # releases. Provide an empty stand-in so the ``Optional[Cache]``
    # annotation used in the forward signature still resolves at import time.
    class Cache:
        """Placeholder for ``transformers.cache_utils.Cache`` (annotations only)."""
def cohere_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
):
    """Flash-attention forward for Cohere attention, with sequence parallel.

    Meant to be bound as the ``forward`` of a Cohere flash-attention module:
    it reads ``q_proj``/``k_proj``/``v_proj``/``o_proj``, ``rotary_emb``,
    the optional ``q_norm``/``k_norm`` and ``_flash_attention_forward`` from
    ``self``. When ``torch.distributed`` is initialized, the
    sequence-parallel world size is greater than one and the module is in
    training mode, q/k/v go through ``pre_process_for_sequence_parallel_attn``
    before attention and the output through
    ``post_process_for_sequence_parallel_attn`` afterwards.

    Args:
        hidden_states: 3-D input ``(bsz, q_len, hidden)``.
        attention_mask: Mask forwarded to ``self._flash_attention_forward``.
        position_ids: Positions passed to ``self.rotary_emb``.
        past_key_value: Optional KV cache. Replaced by
            ``self.past_key_value`` when the module has that attribute.
        output_attentions: Accepted for interface compatibility and ignored;
            attention weights are never returned.
        use_cache: Unused; accepted for interface compatibility.
        cache_position: Forwarded to ``past_key_value.update``.
        **kwargs: Ignored.

    Returns:
        ``(attn_output, None, past_key_value)``. ``attn_output`` has shape
        ``(bsz, q_len, self.hidden_size)`` after ``self.o_proj``.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim)
    # The optional q/k norm is applied on the (bsz, q_len, heads, head_dim)
    # layout, i.e. before heads are moved to dim 1.
    if self.use_qk_norm:
        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

    # (bsz, q_len, heads, head_dim) -> (bsz, heads, q_len, head_dim)
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    cos, sin = self.rotary_emb(value_states, position_ids)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin)

    past_key_value = getattr(self, 'past_key_value', past_key_value)
    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position is needed
        # for the static cache.
        cache_kwargs = {
            'sin': sin,
            'cos': cos,
            'cache_position': cache_position
        }
        key_states, value_states = past_key_value.update(
            key_states, value_states, self.layer_idx, cache_kwargs)

    # Flash attention expects [batch_size, seq_len, num_heads, head_dim].
    # TODO: these transposes are wasteful; avoiding them requires a KV cache
    # that stores this layout directly.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = self.attention_dropout if self.training else 0.0

    # In PEFT, layer norms are often kept in float32 for training stability,
    # which silently upcasts the hidden states. Cast q/k/v back to a
    # half-precision dtype before calling flash attention.
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            # torch>=2.4 deprecates get_autocast_gpu_dtype() in favour of
            # get_autocast_dtype('cuda'); keep the old call as a fallback.
            if hasattr(torch, 'get_autocast_dtype'):
                target_dtype = torch.get_autocast_dtype('cuda')
            else:
                target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized.
        elif hasattr(self.config, '_pre_quantization_dtype'):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    enable_sequence_parallel = (
        dist.is_initialized() and get_sequence_parallel_world_size() > 1
        and self.training)
    ori_num_heads = self.num_heads
    if enable_sequence_parallel:
        # NOTE(review): the redistribution presumably exchanges sequence
        # shards for head shards, so the sequence length and head count below
        # may differ from q_len / self.num_heads.
        query_states, key_states, value_states = \
            pre_process_for_sequence_parallel_attn(
                query_states, key_states, value_states)
        # The padded-mask path inside ``_flash_attention_forward``
        # (transformers' ``_upad_input``) reshapes with ``self.num_heads``,
        # so mirror the local head count while attention runs.
        self.num_heads = query_states.shape[-2]

    try:
        # Pass the real query length of the tensors handed to flash
        # attention. This equals q_len when sequence parallel is off.
        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_states.shape[1],
            dropout=dropout_rate)
    finally:
        # Always restore the module attribute, even if attention raised.
        if enable_sequence_parallel:
            self.num_heads = ori_num_heads

    if enable_sequence_parallel:
        attn_output = post_process_for_sequence_parallel_attn(attn_output)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Flash attention never materializes attention probabilities, so
    # ``output_attentions`` cannot be honoured.
    attn_weights = None
    return attn_output, attn_weights, past_key_value
|