update modeling_yi.py
modeling_yi.py  CHANGED  +11 -13
@@ -6,7 +6,6 @@ import torch.utils.checkpoint
 from einops import repeat
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
-
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
@@ -18,17 +17,17 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_available,
     logging,
     replace_return_docstrings,
 )
 
 from .configuration_yi import YiConfig
 
-
-if is_flash_attn_available():
+is_flash_attn_available = True
+try:
     from flash_attn import flash_attn_func
-
+except Exception:
+    is_flash_attn_available = False
 
 logger = logging.get_logger(__name__)
 
@@ -224,7 +223,6 @@ class YiAttention(nn.Module):
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, q_len, _ = hidden_states.size()
-        flash_attn_available = is_flash_attn_available()
 
         query_states = self.q_proj(hidden_states).view(
             bsz, q_len, self.num_heads, self.head_dim
@@ -237,7 +235,7 @@
             bsz, q_len, self.num_key_value_heads, self.head_dim
         )
 
-        if not flash_attn_available:
+        if not is_flash_attn_available:
             if self.num_key_value_groups > 1:
                 key_states = repeat(
                     key_states, f"b n h d -> b n (h {self.num_key_value_groups}) d"
@@ -251,13 +249,13 @@
             key_states = key_states.transpose(1, 2)
             value_states = value_states.transpose(1, 2)
 
-        seq_dim = 1 if flash_attn_available else 2
+        seq_dim = 1 if is_flash_attn_available else 2
         kv_seq_len = key_states.shape[seq_dim]
        if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[seq_dim]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids,
+            query_states, key_states, cos, sin, position_ids, is_flash_attn_available
         )
 
         if past_key_value is not None:
@@ -267,7 +265,7 @@
 
         past_key_value = (key_states, value_states) if use_cache else None
 
-        if flash_attn_available:
+        if is_flash_attn_available:
             attn_output = flash_attn_func(
                 query_states, key_states, value_states, dropout_p=0.0, causal=True
             )
@@ -308,7 +306,7 @@
                 f" {attn_output.size()}"
             )
 
-        if not flash_attn_available:
+        if not is_flash_attn_available:
             attn_output = attn_output.transpose(1, 2)
 
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -541,7 +539,7 @@ class YiModel(YiPreTrainedModel):
     def _prepare_decoder_attention_mask(
         self, attention_mask, input_ids, inputs_embeds, past_key_values_length
     ):
-        input_shape = input_ids.shape
+        input_shape = input_ids.shape if input_ids else inputs_embeds.shape[:-1]
         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
         combined_attention_mask = None
@@ -631,7 +629,7 @@
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if not is_flash_attn_available():
+        if not is_flash_attn_available:
             # embed positions
             if attention_mask is None:
                 attention_mask = torch.ones(
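
For context, a minimal sketch of the pattern this commit moves to, assuming only that `flash_attn` may or may not be installed: the module probes for `flash_attn` once at import time and records the result in a module-level `is_flash_attn_available` flag, and the attention forward pass branches on that flag instead of calling `transformers.utils.is_flash_attn_available()`. The `toy_attention` helper below is illustrative only, not the Yi implementation; it shows why the diff uses `seq_dim = 1` on the flash path (flash_attn_func keeps tensors as (batch, seq, heads, head_dim)) and `seq_dim = 2` on the eager fallback (heads moved ahead of the sequence axis).

```python
import torch

# Probe for flash-attn once at import time, mirroring the pattern in the diff:
# fall back to a plain PyTorch path if the package is missing or fails to load.
is_flash_attn_available = True
try:
    from flash_attn import flash_attn_func  # expects (batch, seq, heads, head_dim)
except Exception:
    is_flash_attn_available = False


def toy_attention(q, k, v):
    """Illustrative dispatch only; q, k, v are (batch, seq, heads, head_dim)."""
    if is_flash_attn_available:
        # Flash path keeps the (b, s, h, d) layout, so the sequence axis is dim 1.
        return flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
    # Eager fallback: move heads ahead of the sequence axis -> (b, h, s, d),
    # which is why the sequence axis becomes dim 2 on this path.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    causal = torch.full(
        scores.shape[-2:], float("-inf"), dtype=q.dtype, device=q.device
    ).triu(1)  # mask out future positions
    out = torch.softmax(scores + causal, dim=-1) @ v
    return out.transpose(1, 2)  # back to (b, s, h, d)
```

On the flash path the key/value sequence length is `key_states.shape[1]`; on the fallback path it is `key_states.shape[2]`, which matches the `seq_dim = 1 if is_flash_attn_available else 2` line in the diff. The `flash_attn_func(..., dropout_p=0.0, causal=True)` call is the same one the diff leaves unchanged.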