Commit f78105b (verified) · committed by jjw0126 · 1 Parent(s): 014e6dd

Update modeling_plm.py

Files changed (1): modeling_plm.py (+4, -4)
modeling_plm.py CHANGED
@@ -408,7 +408,7 @@ class PLMAttention(nn.Module):
 
 class PLMFlashAttention2(PLMAttention):
     """
-    DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays
+    PLM flash attention module. This module inherits from `PLMAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
     flash attention and deal with padding tokens in case the input contains any of them.
     """
@@ -431,7 +431,7 @@ class PLMFlashAttention2(PLMAttention):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        # DeepseekV2FlashAttention2 attention does not support output_attentions
+        # PLMFlashAttention2 attention does not support output_attentions
 
         if "padding_mask" in kwargs:
             warnings.warn(
@@ -509,7 +509,7 @@ class PLMFlashAttention2(PLMAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in the correct dtype just to be sure everything works as expected.
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (DeepseekV2RMSNorm handles it correctly)
+        # in fp32. (PLMV2RMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
@@ -587,7 +587,7 @@ class PLMFlashAttention2(PLMAttention):
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in PLMFlashAttention2 __init__.
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in PLMFlashAttention2 __init__.
             causal = self.is_causal and query_length != 1
 
         # Contains at least one padding token in the sequence
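
Context for the fp32 note in the third hunk: flash-attention kernels only run in fp16/bf16, so when upstream LayerNorms or autocast have silently upcast the hidden states to float32, the query/key/value states are cast back to a half-precision dtype before the kernel call. Below is a minimal sketch of that pattern, not the exact code in modeling_plm.py; the function name cast_back_for_flash_attn and the target_dtype argument are illustrative assumptions.

    # Minimal sketch (illustrative, not the repository's actual helper) of the
    # cast-back pattern described in the comment above.
    import torch

    def cast_back_for_flash_attn(query_states, key_states, value_states,
                                 target_dtype=torch.bfloat16):
        # Flash attention only accepts fp16/bf16 inputs; if the states were
        # upcast to float32 (e.g. by fp32 LayerNorms or autocast), cast them
        # back to the dtype the kernel expects.
        if query_states.dtype == torch.float32:
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)
        return query_states, key_states, value_states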