gugarosa committed
Commit: 0bbd68a
Parent(s): de35f90

Adds support for flash-attn rotary embedding and fused dense layers.

Files changed (1):
  1. modeling_mixformer_sequential.py  +59 -19
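
The flash-attn path introduced here is optional at import time: the module probes for the package with `importlib.util.find_spec` and falls back to the stock layers when it is missing. A condensed, standalone sketch of that fallback pattern is shown below; the `linear_cls` selection mirrors the patched MHA constructor, and the rotary embedding is chosen the same way between `FlashRotaryEmbedding` and the in-file `RotaryEmbedding`.

import importlib.util

import torch.nn as nn


def _is_flash_attn_available() -> bool:
    # find_spec only locates the package without importing it,
    # so the probe is cheap and has no side effects.
    return importlib.util.find_spec("flash_attn") is not None


if _is_flash_attn_available():
    from flash_attn.ops.fused_dense import FusedDense
else:
    FusedDense = None

# Prefer the fused kernel when flash-attn is installed, otherwise fall back
# to a plain nn.Linear, as the patched MHA constructor does for Wqkv/out_proj.
linear_cls = FusedDense if FusedDense is not None else nn.Linear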
modeling_mixformer_sequential.py CHANGED
@@ -32,6 +32,7 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 from __future__ import annotations
+import importlib
 
 import math
 from typing import Any, Dict, Optional, Tuple, Union
@@ -48,6 +49,18 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
 from .configuration_mixformer_sequential import MixFormerSequentialConfig
 
 
+def _is_flash_attn_available() -> bool:
+    return importlib.util.find_spec("flash_attn") is not None
+
+
+if _is_flash_attn_available():
+    from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding
+    from flash_attn.ops.fused_dense import FusedDense
+else:
+    FlashRotaryEmbedding = None
+    FusedDense = None
+
+
 @dataclass
 class InferenceParams:
     """Inference parameters passed to model to efficiently calculate
@@ -213,6 +226,7 @@ class RotaryEmbedding(nn.Module):
         dim: int,
         base: int = 10000,
         scale_base: Optional[float] = None,
+        pos_idx_in_fp32: bool = True,
         device: Optional[str] = None,
         **kwargs,
     ) -> None:
@@ -221,15 +235,17 @@ class RotaryEmbedding(nn.Module):
         if scale_base is not None:
             raise NotImplementedError
 
-        # Generate and save the inverse frequency buffer (non-trainable)
         self.dim = dim
-        self.base = base
+        self.base = float(base)
         self.scale_base = scale_base
+        self.pos_idx_in_fp32 = pos_idx_in_fp32
         self.device = device
 
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
+        # Generate and save the inverse frequency buffer (non-trainable)
+        inv_freq = self._compute_inv_freq(device)
         self.register_buffer("inv_freq", inv_freq, persistent=False)
 
+        # Generate and save the scale buffer (non-trainable)
         scale = (
             (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
             if scale_base is not None
@@ -243,23 +259,37 @@ class RotaryEmbedding(nn.Module):
         self._cos_k_cached = None
         self._sin_k_cached = None
 
+    def _compute_inv_freq(self, device: Optional[str] = None) -> torch.FloatTensor:
+        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim))
+
     def _update_cos_sin_cache(
         self, seqlen: int, device: Optional[str] = None, dtype: Optional[torch.dtype] = None
     ) -> None:
-        # Re-generate the inverse frequency buffer if it's not fp32
-        # (for instance if model.half() was called)
-        if self.inv_freq.dtype != "torch.float32":
-            self.inv_freq = 1.0 / (
-                self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
-            )
-
-        if seqlen > self._seq_len_cached or self._cos_cached.device != device or self._cos_cached.dtype != dtype:
+        # Reset the tables if the sequence length has changed, if we are on a
+        # new device or if we are switching from inference mode to training
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached is None
+            or self._cos_cached.device != device
+            or self._cos_cached.dtype != dtype
+            or (self.training and self._cos_cached.is_inference())
+        ):
             self._seq_len_cached = seqlen
-            t = torch.arange(seqlen, device=device, dtype=torch.float32)
 
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq.to(device=t.device, dtype=torch.float32))
+            # fp32 is preferred since the output of `torch.arange` can be quite large
+            # and bf16 would lose a lot of precision
+            if self.pos_idx_in_fp32:
+                t = torch.arange(seqlen, device=device, dtype=torch.float32)
+                if self.inv_freq.dtype != torch.float32:
+                    inv_freq = self._compute_inv_freq(device=device)
+                else:
+                    inv_freq = self.inv_freq
+            else:
+                t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
+                inv_freq = self.inv_freq
+
+            # `torch.outer` is preferred since `torch.einsum` converts from fp32 to fp16 if used with AMP
+            freqs = torch.outer(t, inv_freq)
             if self.scale is None:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
@@ -269,7 +299,7 @@ class RotaryEmbedding(nn.Module):
                 ) / self.scale_base
                 scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
-                # We want the multiplication by scale to happen in fp32
+                # Force the scale multiplication to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                 self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
@@ -520,6 +550,8 @@ class MHA(nn.Module):
         causal: bool = True,
         softmax_scale: Optional[float] = None,
         dropout: float = 0.0,
+        flash_rotary: bool = True,
+        fused_dense: bool = True,
         layer_idx: Optional[int] = None,
         return_residual: bool = False,
         checkpointing: bool = False,
@@ -532,15 +564,23 @@ class MHA(nn.Module):
             rotary_kwargs = {"device": device}
             if rotary_emb_scale_base is not None and rotary_emb_scale_base > 0.0:
                 rotary_kwargs["scale_base"] = rotary_emb_scale_base
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, **rotary_kwargs)
+
+            rotary_cls = FlashRotaryEmbedding if flash_rotary else RotaryEmbedding
+            if rotary_cls is None:
+                rotary_cls = RotaryEmbedding
+            self.rotary_emb = rotary_cls(self.rotary_emb_dim, **rotary_kwargs)
 
         # MLP
         self.n_head, self.n_head_kv, self.head_dim = _find_mha_dims(config, n_head=n_head, n_head_kv=n_head_kv, head_dim=head_dim)
         op_size = self.head_dim * (self.n_head + 2 * self.n_head_kv)
         hidden_size = config.n_embd
 
-        self.Wqkv = nn.Linear(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
+        linear_cls = FusedDense if fused_dense else nn.Linear
+        if linear_cls is None:
+            linear_cls = nn.Linear
+
+        self.Wqkv = linear_cls(hidden_size, op_size, bias=bias, device=device, dtype=dtype)
+        self.out_proj = linear_cls(hidden_size, hidden_size, bias=bias, device=device, dtype=dtype)
 
         # Attention
         self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
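
For reference, the rewritten rotary cache can be read in isolation: with `pos_idx_in_fp32=True` the position indices and inverse frequencies stay in fp32 even after the module has been cast to fp16/bf16, and only the final cos/sin tables are cast back to the working dtype. A rough standalone sketch of that computation follows; the concrete values (`dim=64`, `seqlen=2048`, bf16) are illustrative and not taken from the model config.

import torch

dim, base, seqlen = 64, 10000.0, 2048
working_dtype = torch.bfloat16  # e.g. after model.to(torch.bfloat16)

# Inverse frequencies are recomputed in fp32, mirroring `_compute_inv_freq`.
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# Position indices in fp32: bf16 keeps only ~8 bits of mantissa, so large
# positions would otherwise collapse onto the same representable value.
t = torch.arange(seqlen, dtype=torch.float32)

# `torch.outer` preserves fp32; `torch.einsum` may downcast under AMP.
freqs = torch.outer(t, inv_freq)

# Only the cached tables are cast back to the working dtype.
cos_cached = torch.cos(freqs).to(working_dtype)
sin_cached = torch.sin(freqs).to(working_dtype)

This mirrors the new `_compute_inv_freq` helper and the `torch.outer` call in `_update_cos_sin_cache`.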