Update modeling_plm.py
modeling_plm.py +4 -4
@@ -408,7 +408,7 @@ class PLMAttention(nn.Module):
 
 class PLMFlashAttention2(PLMAttention):
     """
-
+    PLM flash attention module. This module inherits from `PLMAttention` as the weights of the module stays
     untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
     flash attention and deal with padding tokens in case the input contains any of them.
     """
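The restored docstring states the design constraint: PLMFlashAttention2 reuses PLMAttention's weights unchanged and only alters the forward pass, which has to call the public flash-attn API and strip padding tokens when the attention mask contains any. Below is a minimal sketch of that call pattern, assuming the flash-attn 2.x package and ignoring the KV cache; the helper name and shapes are illustrative, not taken from modeling_plm.py.

```python
import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import pad_input, unpad_input


def flash_forward(query, key, value, attention_mask=None, causal=True, dropout_p=0.0):
    # query/key/value: (batch, seq_len, num_heads, head_dim), fp16 or bf16.
    if attention_mask is None:
        # No padding anywhere: the dense kernel can be called directly.
        return flash_attn_func(query, key, value, dropout_p=dropout_p, causal=causal)

    batch_size, seq_len = query.shape[:2]
    # Drop padded positions so the varlen kernel only sees real tokens.
    # ([:4] keeps the sketch compatible with flash-attn versions whose
    # unpad_input returns an extra element.)
    q, indices, cu_q, max_q = unpad_input(query, attention_mask)[:4]
    k, _, cu_k, max_k = unpad_input(key, attention_mask)[:4]
    v, _, _, _ = unpad_input(value, attention_mask)[:4]
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_q, cu_seqlens_k=cu_k,
        max_seqlen_q=max_q, max_seqlen_k=max_k,
        dropout_p=dropout_p, causal=causal,
    )
    # Scatter the outputs back into the padded (batch, seq_len, ...) layout.
    return pad_input(out, indices, batch_size, seq_len)
```

The unpad/pad round trip is what lets the kernel skip padded positions entirely instead of masking them after the fact.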
@@ -431,7 +431,7 @@ class PLMFlashAttention2(PLMAttention):
         use_cache: bool = False,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        #
+        # PLMFlashAttention2 attention does not support output_attentions
 
         if "padding_mask" in kwargs:
             warnings.warn(
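The filled-in comment records that Flash Attention 2 never materializes the full attention matrix, so the module cannot return attention weights and output_attentions is effectively treated as False. The context lines below it also fold the deprecated padding_mask kwarg into attention_mask. A hedged sketch of that prologue, following the usual Transformers convention rather than this file's exact code (the helper name and warning text are assumptions):

```python
import warnings


def normalize_flash_attn_kwargs(attention_mask, output_attentions, **kwargs):
    # Flash Attention 2 computes attention without ever building the full
    # score matrix, so attention probabilities cannot be returned.
    output_attentions = False

    # Accept the deprecated `padding_mask` kwarg from older callers and fold
    # it into `attention_mask`, as the diff context above does.
    if "padding_mask" in kwargs:
        warnings.warn(
            "Passing `padding_mask` is deprecated; please use `attention_mask` instead."
        )
        attention_mask = kwargs.pop("padding_mask")

    return attention_mask, output_attentions
```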
@@ -509,7 +509,7 @@ class PLMFlashAttention2(PLMAttention):
         # therefore the input hidden states gets silently casted in float32. Hence, we need
         # cast them back in the correct dtype just to be sure everything works as expected.
         # This might slowdown training & inference so it is recommended to not cast the LayerNorms
-        # in fp32. (
+        # in fp32. (PLMV2RMSNorm handles it correctly)
 
         input_dtype = query_states.dtype
         if input_dtype == torch.float32:
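The completed comment closes the note about the fp32 upcast pitfall: flash-attn kernels only run in fp16/bf16, so if autocast or an upcasting norm has left the hidden states in float32 they must be cast back before the kernel call, and the parenthetical now records that PLMV2RMSNorm already handles this correctly. A sketch of the cast-back guard under the common Transformers pattern; the dtype fallback chain is an assumption, not PLM's exact logic:

```python
import torch


def recast_for_flash_attn(module, query_states, key_states, value_states):
    # flash-attn only supports fp16/bf16 inputs. If the states were silently
    # upcasted to float32 (e.g. by autocast or an fp32 norm), cast them back to
    # the dtype the layer actually computes in.
    if query_states.dtype != torch.float32:
        return query_states, key_states, value_states

    if torch.is_autocast_enabled():
        target_dtype = torch.get_autocast_gpu_dtype()
    elif hasattr(module.config, "_pre_quantization_dtype"):
        # Quantized checkpoints keep the pre-quantization dtype on the config.
        target_dtype = module.config._pre_quantization_dtype
    else:
        target_dtype = module.q_proj.weight.dtype

    return (
        query_states.to(target_dtype),
        key_states.to(target_dtype),
        value_states.to(target_dtype),
    )
```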
@@ -587,7 +587,7 @@ class PLMFlashAttention2(PLMAttention):
         if not self._flash_attn_uses_top_left_mask:
             causal = self.is_causal
         else:
-            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in
+            # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in PLMFlashAttention2 __init__.
             causal = self.is_causal and query_length != 1
 
         # Contains at least one padding token in the sequence
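The extended TODO now points at PLMFlashAttention2 __init__, where _flash_attn_uses_top_left_mask is set for flash-attn builds older than 2.1 (the RoCm case): those kernels align the causal mask to the top-left corner instead of the bottom-right, so a single decoding query would only be allowed to attend to the first key. Disabling causal for query_length == 1 works around that. A hedged sketch of how the flag and the causal value are typically derived; the is_flash_attn_greater_or_equal_2_10 helper exists in transformers.utils, the rest is illustrative:

```python
from transformers.utils import is_flash_attn_greater_or_equal_2_10


def resolve_causal(is_causal: bool, query_length: int) -> bool:
    # flash-attn >= 2.1 aligns the causal mask to the bottom-right corner, which
    # is what decoding with a KV cache needs; older builds (e.g. RoCm < 2.1)
    # align it to the top-left.
    uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    if not uses_top_left_mask:
        return is_causal
    # With a top-left aligned mask, a single query token would only see the
    # first key position, so causal is dropped for the q_len == 1 decode step.
    return is_causal and query_length != 1
```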