itlevy commited on
Commit
186a08a
·
verified ·
1 Parent(s): e9d7c68

flash_attention_utils_backward_compat (#2)

Browse files

- flash_attention_utils_backward_compat (a9c47229964d0d7b30c4c544351fb96760ad63bf)

modeling_decilm.py CHANGED
@@ -35,12 +35,12 @@ from transformers.utils import (
35
  replace_return_docstrings,
36
  )
37
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
38
- from transformers.modeling_flash_attention_utils import _flash_attention_forward
39
 
40
  from .configuration_decilm import DeciLMConfig, AttentionConfig, FFNConfig
41
  from .transformers_4_44_2__activations import ACT2FN
42
  from .transformers_4_44_2__cache_utils import Cache, StaticCache
43
  from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter
 
44
  from .transformers_4_44_2__modeling_outputs import (
45
  BaseModelOutputWithPast,
46
  CausalLMOutputWithPast,
@@ -1664,3 +1664,4 @@ class DeciLMLinearAttention(nn.Module):
1664
 
1665
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1666
  return self.linear_attn.forward(x)
 
 
35
  replace_return_docstrings,
36
  )
37
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
38
 
39
  from .configuration_decilm import DeciLMConfig, AttentionConfig, FFNConfig
40
  from .transformers_4_44_2__activations import ACT2FN
41
  from .transformers_4_44_2__cache_utils import Cache, StaticCache
42
  from .transformers_4_44_2__modeling_attn_mask_utils import AttentionMaskConverter
43
+ from .transformers_4_44_2__modeling_flash_attention_utils_backward_compat import _flash_attention_forward
44
  from .transformers_4_44_2__modeling_outputs import (
45
  BaseModelOutputWithPast,
46
  CausalLMOutputWithPast,
 
1664
 
1665
  def forward(self, x: torch.Tensor) -> torch.Tensor:
1666
  return self.linear_attn.forward(x)
1667
+
transformers_4_44_2__modeling_flash_attention_utils_backward_compat.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import os
18
+ from typing import Optional, Tuple
19
+
20
+ import torch
21
+ import torch.nn.functional as F
22
+
23
+ from transformers.utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
24
+
25
+
26
+ if is_flash_attn_2_available():
27
+ try:
28
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
29
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
30
+ _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
31
+ except ImportError:
32
+ raise "Unable to import flash_attn"
33
+
34
+
35
+ def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
36
+ """
37
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
38
+
39
+ Arguments:
40
+ attention_mask (`torch.Tensor`):
41
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
42
+
43
+ Return:
44
+ indices (`torch.Tensor`):
45
+ The indices of non-masked tokens from the flattened input sequence.
46
+ cu_seqlens (`torch.Tensor`):
47
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
48
+ max_seqlen_in_batch (`int`):
49
+ Maximum sequence length in batch.
50
+ """
51
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
52
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
53
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
54
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
55
+ return (
56
+ indices,
57
+ cu_seqlens,
58
+ max_seqlen_in_batch,
59
+ )
60
+
61
+
62
+ def _upad_input(
63
+ query_layer: torch.Tensor,
64
+ key_layer: torch.Tensor,
65
+ value_layer: torch.Tensor,
66
+ attention_mask: torch.Tensor,
67
+ query_length: int,
68
+ ):
69
+ """
70
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
71
+
72
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
73
+ tensors for query, key, value tensors.
74
+
75
+ Arguments:
76
+ query_layer (`torch.Tensor`):
77
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
78
+ key_layer (`torch.Tensor`):
79
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
80
+ value_layer (`torch.Tensor`):
81
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
82
+ attention_mask (`torch.Tensor`):
83
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
84
+ query_length (`int`):
85
+ Target length.
86
+
87
+ Return:
88
+ query_layer (`torch.Tensor`):
89
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
90
+ key_layer (`torch.Tensor`):
91
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
92
+ value_layer (`torch.Tensor`):
93
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
94
+ indices_q (`torch.Tensor`):
95
+ The indices of non-masked tokens from the flattened input target sequence.
96
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
97
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
98
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
99
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
100
+ """
101
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
102
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
103
+
104
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
105
+ value_layer = index_first_axis(
106
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
107
+ )
108
+ if query_length == kv_seq_len:
109
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
110
+ cu_seqlens_q = cu_seqlens_k
111
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
112
+ indices_q = indices_k
113
+ elif query_length == 1:
114
+ max_seqlen_in_batch_q = 1
115
+ cu_seqlens_q = torch.arange(
116
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
117
+ ) # There is a memcpy here, that is very bad.
118
+ indices_q = cu_seqlens_q[:-1]
119
+ query_layer = query_layer.squeeze(1)
120
+ else:
121
+ # The -q_len: slice assumes left padding.
122
+ attention_mask = attention_mask[:, -query_length:]
123
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
124
+
125
+ return (
126
+ query_layer,
127
+ key_layer,
128
+ value_layer,
129
+ indices_q,
130
+ (cu_seqlens_q, cu_seqlens_k),
131
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
132
+ )
133
+
134
+
135
+ def prepare_fa2_from_position_ids(query, key, value, position_ids):
136
+ """
137
+ This function returns necessary arguments to call `flash_attn_varlen_func`.
138
+ All three query, key, value states will be flattened.
139
+ Cummulative lengths of each examples in the batch will be extracted from position_ids.
140
+
141
+ NOTE: ideally cummulative lengths should be prepared at the data collator stage
142
+
143
+ Arguments:
144
+ query (`torch.Tensor`):
145
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
146
+ key (`torch.Tensor`):
147
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
148
+ value (`torch.Tensor`):
149
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
150
+ position_ids (`torch.Tensor`):
151
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
152
+
153
+ Return:
154
+ query (`torch.Tensor`):
155
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
156
+ key (`torch.Tensor`):
157
+ Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
158
+ value (`torch.Tensor`):
159
+ Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim).
160
+ indices_q (`torch.Tensor`):
161
+ The indices of non-masked tokens from the flattened input target sequence.
162
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
163
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
164
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
165
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
166
+ """
167
+ query = query.view(-1, query.size(-2), query.size(-1))
168
+ key = key.view(-1, key.size(-2), key.size(-1))
169
+ value = value.view(-1, value.size(-2), value.size(-1))
170
+ position_ids = position_ids.flatten()
171
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
172
+
173
+ cu_seq_lens = torch.cat(
174
+ (
175
+ indices_q[position_ids == 0],
176
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
177
+ )
178
+ )
179
+
180
+ max_length = position_ids.max() + 1
181
+
182
+ return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
183
+
184
+
185
+ def _flash_attention_forward(
186
+ query_states: torch.Tensor,
187
+ key_states: torch.Tensor,
188
+ value_states: torch.Tensor,
189
+ attention_mask: torch.Tensor,
190
+ query_length: int,
191
+ is_causal: bool,
192
+ dropout: float = 0.0,
193
+ position_ids: Optional[torch.Tensor] = None,
194
+ softmax_scale: Optional[float] = None,
195
+ sliding_window: Optional[int] = None,
196
+ use_top_left_mask: bool = False,
197
+ softcap: Optional[float] = None,
198
+ deterministic: bool = None,
199
+ ):
200
+ """
201
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
202
+ first unpad the input, then computes the attention scores and pad the final attention scores.
203
+
204
+ Args:
205
+ query_states (`torch.Tensor`):
206
+ Input query states to be passed to Flash Attention API
207
+ key_states (`torch.Tensor`):
208
+ Input key states to be passed to Flash Attention API
209
+ value_states (`torch.Tensor`):
210
+ Input value states to be passed to Flash Attention API
211
+ attention_mask (`torch.Tensor`):
212
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
213
+ position of padding tokens and 1 for the position of non-padding tokens.
214
+ dropout (`float`):
215
+ Attention dropout
216
+ softmax_scale (`float`, *optional*):
217
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
218
+ use_top_left_mask (`bool`, defaults to `False`):
219
+ flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
220
+ softcap (`float`, *optional*):
221
+ Softcap for the attention logits, used e.g. in gemma2.
222
+ deterministic (`bool`, *optional*):
223
+ Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
224
+ """
225
+ if not use_top_left_mask:
226
+ causal = is_causal
227
+ else:
228
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
229
+ causal = is_causal and query_length != 1
230
+
231
+ # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
232
+ use_sliding_windows = (
233
+ _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
234
+ )
235
+ flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
236
+
237
+ if is_flash_attn_greater_or_equal("2.4.1"):
238
+ if deterministic is None:
239
+ deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
240
+ flash_kwargs["deterministic"] = deterministic
241
+
242
+ if softcap is not None:
243
+ flash_kwargs["softcap"] = softcap
244
+
245
+ # Contains at least one padding token in the sequence
246
+ if attention_mask is not None:
247
+ batch_size = query_states.shape[0]
248
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
249
+ query_states, key_states, value_states, attention_mask, query_length
250
+ )
251
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
252
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
253
+
254
+ attn_output_unpad = flash_attn_varlen_func(
255
+ query_states,
256
+ key_states,
257
+ value_states,
258
+ cu_seqlens_q=cu_seqlens_q,
259
+ cu_seqlens_k=cu_seqlens_k,
260
+ max_seqlen_q=max_seqlen_in_batch_q,
261
+ max_seqlen_k=max_seqlen_in_batch_k,
262
+ dropout_p=dropout,
263
+ softmax_scale=softmax_scale,
264
+ causal=causal,
265
+ **flash_kwargs,
266
+ )
267
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
268
+
269
+ # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
270
+ # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
271
+ # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
272
+ elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
273
+ batch_size = query_states.size(0)
274
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
275
+ query_states, key_states, value_states, position_ids
276
+ )
277
+
278
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
279
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
280
+
281
+ attn_output = flash_attn_varlen_func(
282
+ query_states,
283
+ key_states,
284
+ value_states,
285
+ cu_seqlens_q=cu_seqlens_q,
286
+ cu_seqlens_k=cu_seqlens_k,
287
+ max_seqlen_q=max_seqlen_in_batch_q,
288
+ max_seqlen_k=max_seqlen_in_batch_k,
289
+ dropout_p=dropout,
290
+ softmax_scale=softmax_scale,
291
+ causal=causal,
292
+ **flash_kwargs,
293
+ )
294
+
295
+ attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
296
+
297
+ else:
298
+ attn_output = flash_attn_func(
299
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
300
+ )
301
+
302
+ return attn_output