yangapku commited on
Commit
1241954
·
1 Parent(s): 5db622a

update batch inference

Browse files
Files changed (1) hide show
  1. modeling_qwen.py +32 -20
modeling_qwen.py CHANGED
@@ -35,6 +35,8 @@ from torch import nn
35
  SUPPORT_CUDA = torch.cuda.is_available()
36
  SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
37
  SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 
 
38
 
39
  from .configuration_qwen import QWenConfig
40
  from .qwen_generation_utils import (
@@ -186,7 +188,7 @@ class FlashSelfAttention(torch.nn.Module):
186
  device=q.device,
187
  )
188
 
189
- if attention_mask is not None:
190
  k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
191
  if q.size(0) == v.size(0):
192
  q = q[indices_k]
@@ -222,7 +224,7 @@ class FlashSelfAttention(torch.nn.Module):
222
  softmax_scale=self.softmax_scale,
223
  causal=is_causal,
224
  )
225
- if attention_mask is not None and seqlen_q == seqlen_k:
226
  output = self.pad_input(output, indices_k, batch_size, seqlen_out)
227
  else:
228
  new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
@@ -451,7 +453,7 @@ class QWenAttention(nn.Module):
451
  def forward(
452
  self,
453
  hidden_states: Optional[Tuple[torch.FloatTensor]],
454
- rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
455
  registered_causal_mask: Optional[torch.Tensor] = None,
456
  layer_past: Optional[Tuple[torch.Tensor]] = None,
457
  attention_mask: Optional[torch.FloatTensor] = None,
@@ -543,11 +545,7 @@ class QWenAttention(nn.Module):
543
  and query.is_cuda
544
  ):
545
  q, k, v = query, key, value
546
- context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
547
-
548
- # b s h d -> b s (h d)
549
- context_layer = context_layer.flatten(2,3).contiguous()
550
-
551
  else:
552
  query = query.permute(0, 2, 1, 3)
553
  if not self.use_cache_quantization:
@@ -561,12 +559,28 @@ class QWenAttention(nn.Module):
561
  and not query.is_cuda
562
  ):
563
  raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
564
- attn_output, attn_weight = self._attn(
565
- query, key, value, registered_causal_mask, attention_mask, head_mask
566
- )
567
- context_layer = self._merge_heads(
568
- attn_output, self.num_heads, self.head_dim
569
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
 
571
  attn_output = self.c_proj(context_layer)
572
 
@@ -624,7 +638,7 @@ class QWenBlock(nn.Module):
624
  def forward(
625
  self,
626
  hidden_states: Optional[Tuple[torch.FloatTensor]],
627
- rotary_pos_emb_list: Optional[List[torch.Tensor]] = None,
628
  registered_causal_mask: Optional[torch.Tensor] = None,
629
  layer_past: Optional[Tuple[torch.Tensor]] = None,
630
  attention_mask: Optional[torch.FloatTensor] = None,
@@ -890,11 +904,9 @@ class QWenModel(QWenPreTrainedModel):
890
  ntk_alpha = self.get_ntk_alpha(kv_seq_len)
891
  ntk_alpha_list.append(ntk_alpha)
892
  self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
893
-
894
- rotary_pos_emb_list = []
895
- for ntk_alpha in ntk_alpha_list:
896
- rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
897
- rotary_pos_emb_list.append(rotary_pos_emb)
898
 
899
  hidden_states = self.drop(hidden_states)
900
  output_shape = input_shape + (hidden_states.size(-1),)
 
35
  SUPPORT_CUDA = torch.cuda.is_available()
36
  SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
37
  SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
38
+ SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
39
+
40
 
41
  from .configuration_qwen import QWenConfig
42
  from .qwen_generation_utils import (
 
188
  device=q.device,
189
  )
190
 
191
+ if batch_size > 1 and attention_mask is not None:
192
  k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
193
  if q.size(0) == v.size(0):
194
  q = q[indices_k]
 
224
  softmax_scale=self.softmax_scale,
225
  causal=is_causal,
226
  )
227
+ if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
228
  output = self.pad_input(output, indices_k, batch_size, seqlen_out)
229
  else:
230
  new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
 
453
  def forward(
454
  self,
455
  hidden_states: Optional[Tuple[torch.FloatTensor]],
456
+ rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
457
  registered_causal_mask: Optional[torch.Tensor] = None,
458
  layer_past: Optional[Tuple[torch.Tensor]] = None,
459
  attention_mask: Optional[torch.FloatTensor] = None,
 
545
  and query.is_cuda
546
  ):
547
  q, k, v = query, key, value
548
+ attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
 
 
 
 
549
  else:
550
  query = query.permute(0, 2, 1, 3)
551
  if not self.use_cache_quantization:
 
559
  and not query.is_cuda
560
  ):
561
  raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED)
562
+
563
+ if not self.use_cache_quantization and SUPPORT_TORCH2:
564
+ causal_mask = registered_causal_mask[
565
+ :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2)
566
+ ]
567
+ if attention_mask is not None:
568
+ attention_mask = attention_mask.expand(
569
+ -1, -1, causal_mask.size(2), -1
570
+ ).masked_fill(~causal_mask, torch.finfo(query.dtype).min)
571
+ else:
572
+ attention_mask = causal_mask
573
+ attn_output = F.scaled_dot_product_attention(
574
+ query, key, value, attn_mask=attention_mask
575
+ ).transpose(1, 2)
576
+ attn_weight = None
577
+ else:
578
+ attn_output, attn_weight = self._attn(
579
+ query, key, value, registered_causal_mask, attention_mask, head_mask
580
+ )
581
+ context_layer = self._merge_heads(
582
+ attn_output, self.num_heads, self.head_dim
583
+ )
584
 
585
  attn_output = self.c_proj(context_layer)
586
 
 
638
  def forward(
639
  self,
640
  hidden_states: Optional[Tuple[torch.FloatTensor]],
641
+ rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None,
642
  registered_causal_mask: Optional[torch.Tensor] = None,
643
  layer_past: Optional[Tuple[torch.Tensor]] = None,
644
  attention_mask: Optional[torch.FloatTensor] = None,
 
904
  ntk_alpha = self.get_ntk_alpha(kv_seq_len)
905
  ntk_alpha_list.append(ntk_alpha)
906
  self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list
907
+ rotary_pos_emb_list = [
908
+ self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list
909
+ ]
 
 
910
 
911
  hidden_states = self.drop(hidden_states)
912
  output_shape = input_shape + (hidden_states.size(-1),)