Updated flash attention usage
configuration_megatron_gpt.py
CHANGED
@@ -81,7 +81,7 @@ class MegatronGPTConfig(PretrainedConfig):
             Whether to calculate and apply the relative position bias within the attention function.
             If this is False, then model.generate will require you to calculate the triangular attention
             mask and pass it through in the attention mask.
-
+        use_flash_attention (`bool`, *optional*, defaults to `False`):
             When calculating attention, whether to attempt to use flash attention if it's installed, or to always skip and use the regular method.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
@@ -120,7 +120,7 @@ class MegatronGPTConfig(PretrainedConfig):
         eos_token_id=2,
         tie_word_embeddings=False,
         rope_scaling=None,
-
+        use_flash_attention=False,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -144,7 +144,7 @@ class MegatronGPTConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.self_attention_relative_position_bias = self_attention_relative_position_bias
         self.tie_word_embeddings = tie_word_embeddings
-        self.
+        self.use_flash_attention = use_flash_attention
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()

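The configuration change threads a single new boolean, use_flash_attention, through the docstring, the __init__ signature, and the stored attributes, with a default of False. A minimal sketch of how a user might opt in when loading a checkpoint that ships this custom code (the repo id below is a placeholder, not the actual model name):

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; substitute the actual MegatronGPT checkpoint.
checkpoint = "org-name/megatron-gpt-model"

# The new flag defaults to False, so flash attention is opt-in.
config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
config.use_flash_attention = True

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    config=config,
    trust_remote_code=True,
)

Because the default stays False, existing configs keep their current behavior and flash attention only engages when explicitly requested.
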
modeling_megatron_gpt.py
CHANGED
@@ -222,7 +222,7 @@ class MegatronGPTAttention(nn.Module):
         present = (key, value) if use_cache else None

         # Compute attention
-        if not HAS_FLASH or output_attentions or head_mask is not None or self.config.
+        if not HAS_FLASH or output_attentions or head_mask is not None or not self.config.use_flash_attention:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
         else:
             attn_output = self._flash_attn(query, key, value, attention_mask)
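
The updated gate falls back to the regular attention path whenever flash attention is not installed, attention weights are requested, a head mask is supplied, or the new config flag is left off. A standalone sketch of the same decision logic, assuming HAS_FLASH is set by an import probe (the probe shown here is an assumption, not necessarily how the module defines it):

import importlib.util

# Probe for the flash-attn package once at import time (one common pattern;
# the module's actual HAS_FLASH definition may differ).
HAS_FLASH = importlib.util.find_spec("flash_attn") is not None

def pick_attention_path(config, output_attentions, head_mask):
    """Mirror the gating condition above: use flash attention only when it is
    available and enabled, and nothing requires the explicit weight matrix."""
    use_flash = (
        HAS_FLASH
        and not output_attentions        # fused kernels don't return attn weights
        and head_mask is None            # head masking needs the explicit weights
        and config.use_flash_attention   # new opt-in flag from the config
    )
    return "flash" if use_flash else "regular"

Routing through the regular path for output_attentions or head_mask reflects that the fused flash kernel never materializes the attention-weight matrix those options depend on.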