jupyterjazz commited on
Commit
6e55444
1 Parent(s): 0f0bed6

style: removing unused files, black, isort

Browse files

Signed-off-by: jupyterjazz <[email protected]>

Files changed (10) hide show
  1. block.py +5 -4
  2. embedding.py +27 -13
  3. mha.py +101 -42
  4. mlp.py +33 -15
  5. modeling_lora.py +30 -18
  6. modeling_xlm_roberta.py +116 -194
  7. modeling_xlm_roberta_for_glue.py +0 -109
  8. rotary.py +43 -16
  9. stochastic_depth.py +1 -1
  10. xlm_padding.py +24 -10
block.py CHANGED
@@ -8,15 +8,14 @@ from typing import Optional
8
 
9
  import torch
10
  import torch.nn as nn
11
- import torch.nn.functional as F
12
  from torch import Tensor
13
 
14
- from .stochastic_depth import StochasticDepth
15
  from .mha import MHA
16
  from .mlp import Mlp
 
17
 
18
  try:
19
- from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
20
  except ImportError:
21
  layer_norm_fn, RMSNorm = None, None
22
 
@@ -233,7 +232,9 @@ class Block(nn.Module):
233
  is_rms_norm=isinstance(self.norm1, RMSNorm),
234
  )
235
  if not isinstance(self.mlp, nn.Identity):
236
- mlp_out = self.mlp(hidden_states, adapter_mask=mixer_kwargs.get('adapter_mask'))
 
 
237
  if self.return_residual: # mlp out is actually a pair here
238
  mlp_out, hidden_states = mlp_out
239
  if not self.fused_dropout_add_ln:
 
8
 
9
  import torch
10
  import torch.nn as nn
 
11
  from torch import Tensor
12
 
 
13
  from .mha import MHA
14
  from .mlp import Mlp
15
+ from .stochastic_depth import StochasticDepth
16
 
17
  try:
18
+ from flash_attn.ops.triton.layer_norm import RMSNorm, layer_norm_fn
19
  except ImportError:
20
  layer_norm_fn, RMSNorm = None, None
21
 
 
232
  is_rms_norm=isinstance(self.norm1, RMSNorm),
233
  )
234
  if not isinstance(self.mlp, nn.Identity):
235
+ mlp_out = self.mlp(
236
+ hidden_states, adapter_mask=mixer_kwargs.get("adapter_mask")
237
+ )
238
  if self.return_residual: # mlp out is actually a pair here
239
  mlp_out, hidden_states = mlp_out
240
  if not self.fused_dropout_add_ln:
embedding.py CHANGED
@@ -5,10 +5,8 @@
5
 
6
  import torch
7
  import torch.nn as nn
8
- from einops import rearrange
9
- from torch import Tensor
10
-
11
- from transformers.models.xlm_roberta.modeling_xlm_roberta import create_position_ids_from_input_ids
12
 
13
 
14
  class XLMRobertaEmbeddings(nn.Module):
@@ -38,20 +36,29 @@ class XLMRobertaEmbeddings(nn.Module):
38
  max_position_embeddings, embed_dim, **factory_kwargs
39
  )
40
  if self.type_vocab_size > 0:
41
- self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs)
 
 
42
 
43
- def forward(self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None):
 
 
44
  """
45
  input_ids: (batch, seqlen)
46
  position_ids: (batch, seqlen)
47
  token_type_ids: (batch, seqlen)
 
48
  """
49
  batch_size, seqlen = input_ids.shape
50
  if adapter_mask is not None:
51
  unique_tasks = torch.unique(adapter_mask)
52
  embedding_dtype = next(self.word_embeddings.parameters()).dtype
53
- embeddings = torch.empty(*input_ids.shape, self.word_embeddings.embedding_dim,
54
- dtype=embedding_dtype, device=input_ids.device)
 
 
 
 
55
  for task_id in unique_tasks:
56
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
57
  task_input_ids = input_ids[task_indices]
@@ -61,20 +68,27 @@ class XLMRobertaEmbeddings(nn.Module):
61
  embeddings = self.word_embeddings(input_ids)
62
  if self.max_position_embeddings > 0:
63
  if position_ids is None:
64
- position_ids = create_position_ids_from_input_ids(input_ids, padding_idx=self.word_embeddings.padding_idx).to(input_ids.device)
65
- # position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
 
66
  position_embeddings = self.position_embeddings(position_ids)
67
  embeddings = embeddings + position_embeddings
68
  if self.type_vocab_size > 0:
69
  if token_type_ids is None:
70
- token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
 
 
71
 
72
  if adapter_mask is not None:
73
  unique_tasks = torch.unique(adapter_mask)
74
  for task_id in unique_tasks:
75
- task_token_type_embeddings = self.token_type_embeddings(token_type_ids, task_id=task_id)
 
 
76
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
77
- embeddings[task_indices] = embeddings[task_indices] + task_token_type_embeddings
 
 
78
  else:
79
  token_type_embeddings = self.token_type_embeddings(token_type_ids)
80
  embeddings = embeddings + token_type_embeddings
 
5
 
6
  import torch
7
  import torch.nn as nn
8
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import \
9
+ create_position_ids_from_input_ids
 
 
10
 
11
 
12
  class XLMRobertaEmbeddings(nn.Module):
 
36
  max_position_embeddings, embed_dim, **factory_kwargs
37
  )
38
  if self.type_vocab_size > 0:
39
+ self.token_type_embeddings = nn.Embedding(
40
+ type_vocab_size, embed_dim, **factory_kwargs
41
+ )
42
 
43
+ def forward(
44
+ self, input_ids, position_ids=None, token_type_ids=None, adapter_mask=None
45
+ ):
46
  """
47
  input_ids: (batch, seqlen)
48
  position_ids: (batch, seqlen)
49
  token_type_ids: (batch, seqlen)
50
+ adapter_mask: (batch, 1)
51
  """
52
  batch_size, seqlen = input_ids.shape
53
  if adapter_mask is not None:
54
  unique_tasks = torch.unique(adapter_mask)
55
  embedding_dtype = next(self.word_embeddings.parameters()).dtype
56
+ embeddings = torch.empty(
57
+ *input_ids.shape,
58
+ self.word_embeddings.embedding_dim,
59
+ dtype=embedding_dtype,
60
+ device=input_ids.device
61
+ )
62
  for task_id in unique_tasks:
63
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
64
  task_input_ids = input_ids[task_indices]
 
68
  embeddings = self.word_embeddings(input_ids)
69
  if self.max_position_embeddings > 0:
70
  if position_ids is None:
71
+ position_ids = create_position_ids_from_input_ids(
72
+ input_ids, padding_idx=self.word_embeddings.padding_idx
73
+ ).to(input_ids.device)
74
  position_embeddings = self.position_embeddings(position_ids)
75
  embeddings = embeddings + position_embeddings
76
  if self.type_vocab_size > 0:
77
  if token_type_ids is None:
78
+ token_type_ids = torch.zeros(
79
+ seqlen, dtype=torch.long, device=input_ids.device
80
+ )
81
 
82
  if adapter_mask is not None:
83
  unique_tasks = torch.unique(adapter_mask)
84
  for task_id in unique_tasks:
85
+ task_token_type_embeddings = self.token_type_embeddings(
86
+ token_type_ids, task_id=task_id
87
+ )
88
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
89
+ embeddings[task_indices] = (
90
+ embeddings[task_indices] + task_token_type_embeddings
91
+ )
92
  else:
93
  token_type_embeddings = self.token_type_embeddings(token_type_ids)
94
  embeddings = embeddings + token_type_embeddings
mha.py CHANGED
@@ -1,5 +1,8 @@
 
 
 
 
1
  # Copyright (c) 2023, Tri Dao.
2
- # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
3
 
4
  import math
5
  from functools import partial
@@ -9,20 +12,19 @@ import torch.nn as nn
9
  from einops import rearrange, repeat
10
 
11
  try:
12
- from flash_attn import (
13
- flash_attn_kvpacked_func,
14
- flash_attn_qkvpacked_func,
15
- flash_attn_varlen_kvpacked_func,
16
- flash_attn_varlen_qkvpacked_func,
17
- flash_attn_with_kvcache,
18
- )
19
  except ImportError:
20
  flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
21
  flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
22
  flash_attn_with_kvcache = None
23
 
24
  try:
25
- from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
 
26
  except ImportError:
27
  FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
28
 
@@ -42,7 +44,9 @@ def get_alibi_slopes(nheads):
42
  closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
43
  return (
44
  get_slopes_power_of_2(closest_power_of_2)
45
- + get_alibi_slopes(2 * closest_power_of_2)[0::2][: nheads - closest_power_of_2]
 
 
46
  )
47
 
48
 
@@ -67,7 +71,9 @@ class FlashSelfAttention(nn.Module):
67
  deterministic=False,
68
  ):
69
  super().__init__()
70
- assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
 
 
71
  assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
72
  self.causal = causal
73
  self.softmax_scale = softmax_scale
@@ -147,7 +153,9 @@ class FlashCrossAttention(nn.Module):
147
  deterministic=False,
148
  ):
149
  super().__init__()
150
- assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
 
 
151
  assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
152
  self.causal = causal
153
  self.softmax_scale = softmax_scale
@@ -313,7 +321,10 @@ class CrossAttention(nn.Module):
313
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
314
  if key_padding_mask is not None:
315
  padding_mask = torch.full(
316
- (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
 
 
 
317
  )
318
  padding_mask.masked_fill_(key_padding_mask, 0.0)
319
  # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
@@ -425,20 +436,26 @@ class MHA(nn.Module):
425
  else:
426
  alibi_slopes = None
427
  if window_size != (-1, -1):
428
- assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn"
 
 
429
 
430
  self.num_heads = num_heads
431
  self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
432
  assert (
433
  self.num_heads % self.num_heads_kv == 0
434
  ), "num_heads must be divisible by num_heads_kv"
435
- assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
 
 
436
  self.head_dim = self.embed_dim // num_heads
437
  qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
438
  kv_dim = 2 * self.head_dim * self.num_heads_kv
439
 
440
  if self.rotary_emb_dim > 0:
441
- assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
 
 
442
  assert RotaryEmbedding is not None, "rotary_emb is not installed"
443
  self.rotary_emb = RotaryEmbedding(
444
  self.rotary_emb_dim,
@@ -453,23 +470,33 @@ class MHA(nn.Module):
453
 
454
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
455
  linear_resid_cls = (
456
- LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
 
 
457
  )
458
  wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
459
  inner_attn_cls = (
460
- partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size)
 
 
461
  if use_flash_attn
462
  else SelfAttention
463
  )
464
  inner_cross_attn_cls = (
465
- partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size)
 
 
466
  if use_flash_attn
467
  else CrossAttention
468
  )
469
  if not self.cross_attn:
470
- self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs)
 
 
471
  else:
472
- self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs)
 
 
473
  self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
474
  if self.dwconv:
475
  if self.num_heads_kv == self.num_heads:
@@ -480,7 +507,9 @@ class MHA(nn.Module):
480
  self.dwconv_q = nn.Conv1d(
481
  embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
482
  )
483
- self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
 
 
484
  self.inner_attn = inner_attn_cls(
485
  causal=causal,
486
  softmax_scale=softmax_scale,
@@ -489,7 +518,9 @@ class MHA(nn.Module):
489
  self.inner_cross_attn = inner_cross_attn_cls(
490
  causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
491
  )
492
- self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)
 
 
493
 
494
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
495
  dtype = self.out_proj.weight.dtype if dtype is None else dtype
@@ -507,7 +538,9 @@ class MHA(nn.Module):
507
  def _update_kv_cache(self, kv, inference_params):
508
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
509
  assert not self.dwconv, "Generation does not support dwconv yet"
510
- assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
 
 
511
  return _update_kv_cache(kv, inference_params, self.layer_idx)
512
 
513
  def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
@@ -523,7 +556,10 @@ class MHA(nn.Module):
523
  self.rotary_emb._update_cos_sin_cache(
524
  inference_params.max_seqlen, device=q.device, dtype=q.dtype
525
  )
526
- rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached
 
 
 
527
  else:
528
  rotary_cos, rotary_sin = None, None
529
  batch = q.shape[0]
@@ -545,7 +581,9 @@ class MHA(nn.Module):
545
  cache_seqlens=cache_seqlens,
546
  softmax_scale=self.inner_cross_attn.softmax_scale,
547
  causal=self.inner_cross_attn.causal,
548
- rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
 
 
549
  alibi_slopes=alibi_slopes,
550
  )
551
  return context
@@ -640,40 +678,49 @@ class MHA(nn.Module):
640
  )
641
  )
642
  rotary_max_seqlen = (
643
- inference_params.max_sequence_len if inference_params is not None else max_seqlen
 
 
644
  )
645
- batch, seqlen = x.shape[:2]
646
- lora_kwargs = {}
647
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
648
  assert x_kv is None and mixer_subset is None
649
 
650
  if adapter_mask is not None:
651
  unique_tasks = torch.unique(adapter_mask)
652
  qkv_dtype = next(self.Wqkv.parameters()).dtype
653
- qkv = torch.empty(*x.shape[:-1], self.Wqkv.out_features,
654
- dtype=qkv_dtype, device=x.device)
 
 
 
 
655
  for task_id in unique_tasks:
656
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
657
  task_tensor = x[task_indices]
658
  if not self.return_residual:
659
  task_qkv = self.Wqkv(task_tensor, task_id=task_id)
660
  else:
661
- task_qkv, _ = self.Wqkv(task_tensor, task_id=task_id, residual=True)
 
 
662
  qkv[task_indices] = task_qkv
663
  else:
664
  if not self.return_residual:
665
  qkv = self.Wqkv(x)
666
  else:
667
- if hasattr(self.Wqkv, 'parametrizations'):
668
  qkv, x = self.Wqkv(x, residual=True)
669
  else:
670
  qkv, x = self.Wqkv(x)
671
 
672
  if self.dwconv:
673
  qkv = rearrange(
674
- self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
675
  ).contiguous()
676
- qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
 
 
677
  if (
678
  inference_params is None
679
  or inference_params.seqlen_offset == 0
@@ -691,7 +738,9 @@ class MHA(nn.Module):
691
  if not self.checkpointing:
692
  context = self.inner_attn(qkv, **kwargs)
693
  else:
694
- context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
 
 
695
  else:
696
  context = self._update_kvcache_attention(
697
  qkv[:, :, 0], qkv[:, :, 1:], inference_params
@@ -720,13 +769,17 @@ class MHA(nn.Module):
720
  q = qkv[..., : self.num_heads * self.head_dim]
721
  kv = qkv[..., self.num_heads * self.head_dim :]
722
  q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
723
- kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
 
 
724
  if self.dwconv:
725
  q = rearrange(
726
- self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
727
  ).contiguous()
728
  kv = rearrange(
729
- self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
 
730
  ).contiguous()
731
  if (
732
  inference_params is None
@@ -752,14 +805,20 @@ class MHA(nn.Module):
752
  else:
753
  context = self._update_kvcache_attention(q, kv, inference_params)
754
  else:
755
- context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params)
 
 
756
 
757
  inp = rearrange(context, "... h d -> ... (h d)")
758
  if adapter_mask is not None:
759
  unique_tasks = torch.unique(adapter_mask)
760
  out_dtype = next(self.out_proj.parameters()).dtype
761
- out = torch.empty(*inp.shape[:-1], self.out_proj.out_features,
762
- dtype=out_dtype, device=inp.device)
 
 
 
 
763
  for task_id in unique_tasks:
764
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
765
  task_tensor = inp[task_indices]
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py
2
+ # Commit id: 6bbc532388e61185a92e2a563126739967b4c8c5
3
+ # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
+
5
  # Copyright (c) 2023, Tri Dao.
 
6
 
7
  import math
8
  from functools import partial
 
12
  from einops import rearrange, repeat
13
 
14
  try:
15
+ from flash_attn import (flash_attn_kvpacked_func,
16
+ flash_attn_qkvpacked_func,
17
+ flash_attn_varlen_kvpacked_func,
18
+ flash_attn_varlen_qkvpacked_func,
19
+ flash_attn_with_kvcache)
 
 
20
  except ImportError:
21
  flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
22
  flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
23
  flash_attn_with_kvcache = None
24
 
25
  try:
26
+ from flash_attn.ops.fused_dense import (ColumnParallelLinear, FusedDense,
27
+ RowParallelLinear)
28
  except ImportError:
29
  FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
30
 
 
44
  closest_power_of_2 = 2 ** math.floor(math.log2(nheads))
45
  return (
46
  get_slopes_power_of_2(closest_power_of_2)
47
+ + get_alibi_slopes(2 * closest_power_of_2)[0::2][
48
+ : nheads - closest_power_of_2
49
+ ]
50
  )
51
 
52
 
 
71
  deterministic=False,
72
  ):
73
  super().__init__()
74
+ assert (
75
+ flash_attn_varlen_qkvpacked_func is not None
76
+ ), "FlashAttention is not installed"
77
  assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
78
  self.causal = causal
79
  self.softmax_scale = softmax_scale
 
153
  deterministic=False,
154
  ):
155
  super().__init__()
156
+ assert (
157
+ flash_attn_varlen_kvpacked_func is not None
158
+ ), "FlashAttention is not installed"
159
  assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
160
  self.causal = causal
161
  self.softmax_scale = softmax_scale
 
321
  scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
322
  if key_padding_mask is not None:
323
  padding_mask = torch.full(
324
+ (batch_size, seqlen_k),
325
+ -10000.0,
326
+ dtype=scores.dtype,
327
+ device=scores.device,
328
  )
329
  padding_mask.masked_fill_(key_padding_mask, 0.0)
330
  # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
 
436
  else:
437
  alibi_slopes = None
438
  if window_size != (-1, -1):
439
+ assert (
440
+ use_flash_attn
441
+ ), "Local (sliding window) attention code path requires flash_attn"
442
 
443
  self.num_heads = num_heads
444
  self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
445
  assert (
446
  self.num_heads % self.num_heads_kv == 0
447
  ), "num_heads must be divisible by num_heads_kv"
448
+ assert (
449
+ self.embed_dim % num_heads == 0
450
+ ), "embed_dim must be divisible by num_heads"
451
  self.head_dim = self.embed_dim // num_heads
452
  qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
453
  kv_dim = 2 * self.head_dim * self.num_heads_kv
454
 
455
  if self.rotary_emb_dim > 0:
456
+ assert (
457
+ not cross_attn
458
+ ), "MHA with rotary embedding does not support cross-attention yet"
459
  assert RotaryEmbedding is not None, "rotary_emb is not installed"
460
  self.rotary_emb = RotaryEmbedding(
461
  self.rotary_emb_dim,
 
470
 
471
  linear_cls = nn.Linear if not fused_bias_fc else FusedDense
472
  linear_resid_cls = (
473
+ LinearResidual
474
+ if not fused_bias_fc
475
+ else partial(FusedDense, return_residual=True)
476
  )
477
  wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
478
  inner_attn_cls = (
479
+ partial(
480
+ FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size
481
+ )
482
  if use_flash_attn
483
  else SelfAttention
484
  )
485
  inner_cross_attn_cls = (
486
+ partial(
487
+ FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size
488
+ )
489
  if use_flash_attn
490
  else CrossAttention
491
  )
492
  if not self.cross_attn:
493
+ self.Wqkv = wqkv_cls(
494
+ embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs
495
+ )
496
  else:
497
+ self.Wq = linear_cls(
498
+ embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs
499
+ )
500
  self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
501
  if self.dwconv:
502
  if self.num_heads_kv == self.num_heads:
 
507
  self.dwconv_q = nn.Conv1d(
508
  embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
509
  )
510
+ self.dwconv_kv = nn.Conv1d(
511
+ kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim
512
+ )
513
  self.inner_attn = inner_attn_cls(
514
  causal=causal,
515
  softmax_scale=softmax_scale,
 
518
  self.inner_cross_attn = inner_cross_attn_cls(
519
  causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
520
  )
521
+ self.out_proj = linear_cls(
522
+ embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs
523
+ )
524
 
525
  def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None):
526
  dtype = self.out_proj.weight.dtype if dtype is None else dtype
 
538
  def _update_kv_cache(self, kv, inference_params):
539
  """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
540
  assert not self.dwconv, "Generation does not support dwconv yet"
541
+ assert (
542
+ self.layer_idx is not None
543
+ ), "Generation requires layer_idx in the constructor"
544
  return _update_kv_cache(kv, inference_params, self.layer_idx)
545
 
546
  def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params):
 
556
  self.rotary_emb._update_cos_sin_cache(
557
  inference_params.max_seqlen, device=q.device, dtype=q.dtype
558
  )
559
+ rotary_cos, rotary_sin = (
560
+ self.rotary_emb._cos_cached,
561
+ self.rotary_emb._sin_cached,
562
+ )
563
  else:
564
  rotary_cos, rotary_sin = None, None
565
  batch = q.shape[0]
 
581
  cache_seqlens=cache_seqlens,
582
  softmax_scale=self.inner_cross_attn.softmax_scale,
583
  causal=self.inner_cross_attn.causal,
584
+ rotary_interleaved=(
585
+ self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False
586
+ ),
587
  alibi_slopes=alibi_slopes,
588
  )
589
  return context
 
678
  )
679
  )
680
  rotary_max_seqlen = (
681
+ inference_params.max_sequence_len
682
+ if inference_params is not None
683
+ else max_seqlen
684
  )
 
 
685
  if not self.cross_attn and self.num_heads_kv == self.num_heads:
686
  assert x_kv is None and mixer_subset is None
687
 
688
  if adapter_mask is not None:
689
  unique_tasks = torch.unique(adapter_mask)
690
  qkv_dtype = next(self.Wqkv.parameters()).dtype
691
+ qkv = torch.empty(
692
+ *x.shape[:-1],
693
+ self.Wqkv.out_features,
694
+ dtype=qkv_dtype,
695
+ device=x.device,
696
+ )
697
  for task_id in unique_tasks:
698
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
699
  task_tensor = x[task_indices]
700
  if not self.return_residual:
701
  task_qkv = self.Wqkv(task_tensor, task_id=task_id)
702
  else:
703
+ task_qkv, _ = self.Wqkv(
704
+ task_tensor, task_id=task_id, residual=True
705
+ )
706
  qkv[task_indices] = task_qkv
707
  else:
708
  if not self.return_residual:
709
  qkv = self.Wqkv(x)
710
  else:
711
+ if hasattr(self.Wqkv, "parametrizations"):
712
  qkv, x = self.Wqkv(x, residual=True)
713
  else:
714
  qkv, x = self.Wqkv(x)
715
 
716
  if self.dwconv:
717
  qkv = rearrange(
718
+ self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2],
719
+ "b d s -> b s d",
720
  ).contiguous()
721
+ qkv = rearrange(
722
+ qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim
723
+ )
724
  if (
725
  inference_params is None
726
  or inference_params.seqlen_offset == 0
 
738
  if not self.checkpointing:
739
  context = self.inner_attn(qkv, **kwargs)
740
  else:
741
+ context = torch.utils.checkpoint.checkpoint(
742
+ self.inner_attn, qkv, **kwargs
743
+ )
744
  else:
745
  context = self._update_kvcache_attention(
746
  qkv[:, :, 0], qkv[:, :, 1:], inference_params
 
769
  q = qkv[..., : self.num_heads * self.head_dim]
770
  kv = qkv[..., self.num_heads * self.head_dim :]
771
  q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
772
+ kv = rearrange(
773
+ kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim
774
+ )
775
  if self.dwconv:
776
  q = rearrange(
777
+ self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2],
778
+ "b d s -> b s d",
779
  ).contiguous()
780
  kv = rearrange(
781
+ self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2],
782
+ "b d s -> b s d",
783
  ).contiguous()
784
  if (
785
  inference_params is None
 
805
  else:
806
  context = self._update_kvcache_attention(q, kv, inference_params)
807
  else:
808
+ context = self._apply_rotary_update_kvcache_attention(
809
+ q, kv, inference_params
810
+ )
811
 
812
  inp = rearrange(context, "... h d -> ... (h d)")
813
  if adapter_mask is not None:
814
  unique_tasks = torch.unique(adapter_mask)
815
  out_dtype = next(self.out_proj.parameters()).dtype
816
+ out = torch.empty(
817
+ *inp.shape[:-1],
818
+ self.out_proj.out_features,
819
+ dtype=out_dtype,
820
+ device=inp.device,
821
+ )
822
  for task_id in unique_tasks:
823
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
824
  task_tensor = inp[task_indices]
mlp.py CHANGED
@@ -8,14 +8,14 @@ import torch.nn as nn
8
  import torch.nn.functional as F
9
  from torch.distributed import ProcessGroup
10
 
11
-
12
  try:
13
  from flash_attn.ops.activations import swiglu
14
  except ImportError:
15
  swiglu = None
16
 
17
  try:
18
- from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
 
19
  except ImportError:
20
  ColumnParallelLinear, RowParallelLinear = None, None
21
 
@@ -41,18 +41,23 @@ class Mlp(nn.Module):
41
  factory_kwargs = {"device": device, "dtype": dtype}
42
  super().__init__()
43
  out_features = out_features if out_features is not None else in_features
44
- hidden_features = hidden_features if hidden_features is not None else in_features * 4
 
 
45
  self.return_residual = return_residual
46
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
47
  self.activation = activation
48
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
 
49
 
50
  def forward(self, x, adapter_mask=None):
51
  if adapter_mask is not None:
52
  unique_tasks = torch.unique(adapter_mask)
53
  fc1_dtype = next(self.fc1.parameters()).dtype
54
- y = torch.empty(*x.shape[:-1], self.fc1.out_features,
55
- dtype=fc1_dtype, device=x.device)
 
56
  for task_id in unique_tasks:
57
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
58
  task_tensor = x[task_indices]
@@ -66,8 +71,9 @@ class Mlp(nn.Module):
66
  if adapter_mask is not None:
67
  unique_tasks = torch.unique(adapter_mask)
68
  fc2_dtype = next(self.fc2.parameters()).dtype
69
- out = torch.empty(*y.shape[:-1], self.fc2.out_features,
70
- dtype=fc2_dtype, device=y.device)
 
71
  for task_id in unique_tasks:
72
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
73
  task_tensor = y[task_indices]
@@ -98,7 +104,9 @@ class ParallelMLP(nn.Module):
98
  assert ColumnParallelLinear is not None, "Need to install fused_dense"
99
  assert RowParallelLinear is not None, "Need to install fused_dense"
100
  out_features = out_features if out_features is not None else in_features
101
- hidden_features = hidden_features if hidden_features is not None else in_features * 4
 
 
102
  self.fc1 = ColumnParallelLinear(
103
  in_features,
104
  hidden_features,
@@ -144,17 +152,25 @@ class GatedMlp(nn.Module):
144
  hidden_features = (
145
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
146
  )
147
- hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
 
 
148
  self.return_residual = return_residual
149
- self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs)
 
 
150
  self.activation = activation
151
- self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
 
 
152
 
153
  def forward(self, x):
154
  y = self.fc1(x)
155
  if self.activation == F.sigmoid: # Special case for GLU
156
  y = F.glu(y, dim=-1)
157
- elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU
 
 
158
  y, gate = y.chunk(2, dim=-1)
159
  y = swiglu(gate, y)
160
  else:
@@ -187,7 +203,9 @@ class ParallelGatedMlp(nn.Module):
187
  hidden_features = (
188
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
189
  )
190
- hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
 
 
191
  if ColumnParallelLinear is None or RowParallelLinear is None:
192
  raise ImportError("fused_dense is not installed")
193
  self.fc1 = ColumnParallelLinear(
@@ -216,4 +234,4 @@ class ParallelGatedMlp(nn.Module):
216
  y, gate = y.chunk(2, dim=-1)
217
  y = y * self.activation(gate)
218
  y = self.fc2(y)
219
- return y
 
8
  import torch.nn.functional as F
9
  from torch.distributed import ProcessGroup
10
 
 
11
  try:
12
  from flash_attn.ops.activations import swiglu
13
  except ImportError:
14
  swiglu = None
15
 
16
  try:
17
+ from flash_attn.ops.fused_dense import (ColumnParallelLinear,
18
+ RowParallelLinear)
19
  except ImportError:
20
  ColumnParallelLinear, RowParallelLinear = None, None
21
 
 
41
  factory_kwargs = {"device": device, "dtype": dtype}
42
  super().__init__()
43
  out_features = out_features if out_features is not None else in_features
44
+ hidden_features = (
45
+ hidden_features if hidden_features is not None else in_features * 4
46
+ )
47
  self.return_residual = return_residual
48
  self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
49
  self.activation = activation
50
+ self.fc2 = nn.Linear(
51
+ hidden_features, out_features, bias=bias2, **factory_kwargs
52
+ )
53
 
54
  def forward(self, x, adapter_mask=None):
55
  if adapter_mask is not None:
56
  unique_tasks = torch.unique(adapter_mask)
57
  fc1_dtype = next(self.fc1.parameters()).dtype
58
+ y = torch.empty(
59
+ *x.shape[:-1], self.fc1.out_features, dtype=fc1_dtype, device=x.device
60
+ )
61
  for task_id in unique_tasks:
62
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
63
  task_tensor = x[task_indices]
 
71
  if adapter_mask is not None:
72
  unique_tasks = torch.unique(adapter_mask)
73
  fc2_dtype = next(self.fc2.parameters()).dtype
74
+ out = torch.empty(
75
+ *y.shape[:-1], self.fc2.out_features, dtype=fc2_dtype, device=y.device
76
+ )
77
  for task_id in unique_tasks:
78
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
79
  task_tensor = y[task_indices]
 
104
  assert ColumnParallelLinear is not None, "Need to install fused_dense"
105
  assert RowParallelLinear is not None, "Need to install fused_dense"
106
  out_features = out_features if out_features is not None else in_features
107
+ hidden_features = (
108
+ hidden_features if hidden_features is not None else in_features * 4
109
+ )
110
  self.fc1 = ColumnParallelLinear(
111
  in_features,
112
  hidden_features,
 
152
  hidden_features = (
153
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
154
  )
155
+ hidden_features = (
156
+ (hidden_features + multiple_of - 1) // multiple_of * multiple_of
157
+ )
158
  self.return_residual = return_residual
159
+ self.fc1 = nn.Linear(
160
+ in_features, 2 * hidden_features, bias=bias1, **factory_kwargs
161
+ )
162
  self.activation = activation
163
+ self.fc2 = nn.Linear(
164
+ hidden_features, out_features, bias=bias2, **factory_kwargs
165
+ )
166
 
167
  def forward(self, x):
168
  y = self.fc1(x)
169
  if self.activation == F.sigmoid: # Special case for GLU
170
  y = F.glu(y, dim=-1)
171
+ elif (
172
+ self.activation == F.silu and swiglu is not None
173
+ ): # Special case for SwiGLU
174
  y, gate = y.chunk(2, dim=-1)
175
  y = swiglu(gate, y)
176
  else:
 
203
  hidden_features = (
204
  hidden_features if hidden_features is not None else int(8 * in_features / 3)
205
  )
206
+ hidden_features = (
207
+ (hidden_features + multiple_of - 1) // multiple_of * multiple_of
208
+ )
209
  if ColumnParallelLinear is None or RowParallelLinear is None:
210
  raise ImportError("fused_dense is not installed")
211
  self.fc1 = ColumnParallelLinear(
 
234
  y, gate = y.chunk(2, dim=-1)
235
  y = y * self.activation(gate)
236
  y = self.fc2(y)
237
+ return y
modeling_lora.py CHANGED
@@ -1,6 +1,5 @@
1
  import math
2
  import os
3
- import warnings
4
  from functools import partial
5
  from typing import Iterator, List, Optional, Tuple, Union
6
 
@@ -12,7 +11,8 @@ from torch.nn import Parameter
12
  from torch.nn import functional as F
13
  from transformers import PretrainedConfig
14
 
15
- from .modeling_xlm_roberta import XLMRobertaFlashConfig, XLMRobertaModel, XLMRobertaPreTrainedModel
 
16
 
17
 
18
  def initialized_weights(
@@ -177,7 +177,9 @@ class LoRAParametrization(nn.Module):
177
 
178
  def new_forward(self, input, task_id=None, residual=False):
179
  if task_id is not None:
180
- weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
 
 
181
  else:
182
  weights = self.weight
183
 
@@ -204,13 +206,21 @@ class LoRAParametrization(nn.Module):
204
 
205
  def new_forward(self, input, task_id=None):
206
  if task_id is not None:
207
- weights = self.parametrizations.weight[0].lora_forward(self.weight, current_task=task_id)
 
 
208
  else:
209
  weights = self.weight
210
 
211
  out = F.embedding(
212
- input, weights, self.padding_idx, self.max_norm,
213
- self.norm_type, self.scale_grad_by_freq, self.sparse)
 
 
 
 
 
 
214
 
215
  return out
216
 
@@ -219,9 +229,7 @@ class LoRAParametrization(nn.Module):
219
 
220
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
221
  def __init__(
222
- self,
223
- config: XLMRobertaFlashConfig,
224
- roberta: Optional[XLMRobertaModel] = None
225
  ):
226
  super().__init__(config)
227
  if roberta is None:
@@ -235,7 +243,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
235
  or len(self._lora_adaptations) < 1
236
  ):
237
  raise ValueError(
238
- f'`lora_adaptations` must be a list and contain at least one element'
239
  )
240
  self._lora_prompts = config.lora_prompts
241
  if (
@@ -244,9 +252,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
244
  or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
245
  ):
246
  raise ValueError(
247
- f'`lora_prompts` must be a dict and contain the same number of elements '
248
- f'as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`.'
249
- )
250
  self._adaptation_map = {
251
  name: idx for idx, name in enumerate(self._lora_adaptations)
252
  }
@@ -261,7 +269,6 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
261
  )
262
  self.main_params_trainable = config.lora_main_params_trainable
263
 
264
-
265
  @property
266
  def rotary_emb_base(self):
267
  return self.roberta.rotary_emb_base
@@ -305,13 +312,14 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
305
  config = XLMRobertaFlashConfig.from_pretrained(
306
  pretrained_model_name_or_path, *model_args, **kwargs
307
  )
308
-
309
  if config.load_trained_adapters:
310
  return super().from_pretrained(
311
  pretrained_model_name_or_path, *model_args, **kwargs
312
  )
313
  else:
314
- roberta = XLMRobertaModel.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
 
 
315
  return cls(config, roberta=roberta)
316
 
317
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
@@ -367,5 +375,9 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
367
  if task_type:
368
  task_id = self._adaptation_map[task_type]
369
  num_examples = 1 if isinstance(sentences, str) else len(sentences)
370
- adapter_mask = torch.full((num_examples,), task_id, dtype=torch.int32, device=self.device)
371
- return self.roberta.encode(sentences, *args, adapter_mask=adapter_mask, **kwargs)
 
 
 
 
 
1
  import math
2
  import os
 
3
  from functools import partial
4
  from typing import Iterator, List, Optional, Tuple, Union
5
 
 
11
  from torch.nn import functional as F
12
  from transformers import PretrainedConfig
13
 
14
+ from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
15
+ XLMRobertaPreTrainedModel)
16
 
17
 
18
  def initialized_weights(
 
177
 
178
  def new_forward(self, input, task_id=None, residual=False):
179
  if task_id is not None:
180
+ weights = self.parametrizations.weight[0].lora_forward(
181
+ self.weight, current_task=task_id
182
+ )
183
  else:
184
  weights = self.weight
185
 
 
206
 
207
  def new_forward(self, input, task_id=None):
208
  if task_id is not None:
209
+ weights = self.parametrizations.weight[0].lora_forward(
210
+ self.weight, current_task=task_id
211
+ )
212
  else:
213
  weights = self.weight
214
 
215
  out = F.embedding(
216
+ input,
217
+ weights,
218
+ self.padding_idx,
219
+ self.max_norm,
220
+ self.norm_type,
221
+ self.scale_grad_by_freq,
222
+ self.sparse,
223
+ )
224
 
225
  return out
226
 
 
229
 
230
  class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
231
  def __init__(
232
+ self, config: XLMRobertaFlashConfig, roberta: Optional[XLMRobertaModel] = None
 
 
233
  ):
234
  super().__init__(config)
235
  if roberta is None:
 
243
  or len(self._lora_adaptations) < 1
244
  ):
245
  raise ValueError(
246
+ f"`lora_adaptations` must be a list and contain at least one element"
247
  )
248
  self._lora_prompts = config.lora_prompts
249
  if (
 
252
  or not all([v in self._lora_adaptations for v in self._lora_prompts.keys()])
253
  ):
254
  raise ValueError(
255
+ f"`lora_prompts` must be a dict and contain the same number of elements "
256
+ f"as `lora_adaptations` with all keys in `lora_prompts` present in `lora_adaptations`."
257
+ )
258
  self._adaptation_map = {
259
  name: idx for idx, name in enumerate(self._lora_adaptations)
260
  }
 
269
  )
270
  self.main_params_trainable = config.lora_main_params_trainable
271
 
 
272
  @property
273
  def rotary_emb_base(self):
274
  return self.roberta.rotary_emb_base
 
312
  config = XLMRobertaFlashConfig.from_pretrained(
313
  pretrained_model_name_or_path, *model_args, **kwargs
314
  )
 
315
  if config.load_trained_adapters:
316
  return super().from_pretrained(
317
  pretrained_model_name_or_path, *model_args, **kwargs
318
  )
319
  else:
320
+ roberta = XLMRobertaModel.from_pretrained(
321
+ pretrained_model_name_or_path, *model_args, **kwargs
322
+ )
323
  return cls(config, roberta=roberta)
324
 
325
  def _register_lora(self, num_adaptations, rank, dropout_p, alpha):
 
375
  if task_type:
376
  task_id = self._adaptation_map[task_type]
377
  num_examples = 1 if isinstance(sentences, str) else len(sentences)
378
+ adapter_mask = torch.full(
379
+ (num_examples,), task_id, dtype=torch.int32, device=self.device
380
+ )
381
+ return self.roberta.encode(
382
+ sentences, *args, adapter_mask=adapter_mask, **kwargs
383
+ )
modeling_xlm_roberta.py CHANGED
@@ -13,39 +13,29 @@ import re
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
16
- import numpy as np
17
 
 
18
  import torch
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
  import torch.utils.checkpoint
22
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
23
- from einops import rearrange
24
- from transformers import PretrainedConfig, AutoTokenizer
 
25
  from transformers.modeling_utils import PreTrainedModel
26
- from transformers.modeling_outputs import MaskedLMOutput,SequenceClassifierOutput
27
- from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaLMHead
28
-
29
  from transformers.models.bert.modeling_bert import (
30
- BaseModelOutputWithPoolingAndCrossAttentions,
31
- BertForPreTrainingOutput,
32
- )
33
 
34
- from typing import List, Optional, Tuple, Union
35
-
36
- from .xlm_padding import (
37
- index_first_axis,
38
- index_first_axis_residual,
39
- pad_input,
40
- unpad_input,
41
- )
42
- from .configuration_xlm_roberta import XLMRobertaFlashConfig
43
  from .block import Block
 
44
  from .embedding import XLMRobertaEmbeddings
45
  from .mha import MHA
46
  from .mlp import FusedMLP, Mlp
47
- from .stochastic_depth import StochasticDepth
48
- from .rotary import RotaryEmbedding
49
 
50
  try:
51
  from flash_attn.ops.fused_dense import FusedDense
@@ -79,7 +69,7 @@ def get_use_flash_attn(config: XLMRobertaFlashConfig):
79
  return False
80
  if importlib.util.find_spec("flash_attn") is None:
81
  logger.warning(
82
- 'flash_attn is not installed. Using PyTorch native attention implementation.'
83
  )
84
  return False
85
  return True
@@ -109,7 +99,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
109
  fused_bias_fc=fused_bias_fc,
110
  use_flash_attn=use_flash_attn,
111
  return_residual=return_residual,
112
- use_alibi=config.position_embedding_type == 'alibi',
113
  **rotary_kwargs,
114
  )
115
  return mixer_cls
@@ -204,15 +194,17 @@ class XLMRobertaEncoder(nn.Module):
204
  def gradient_checkpointing(self, value):
205
  self._grad_checkpointing = value
206
 
207
- def forward(self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None):
 
 
208
  """If subset_mask is not None, we only want output for the subset of the sequence.
209
  This means that we only compute the last layer output for these tokens.
210
  subset_mask: (batch, seqlen), dtype=torch.bool
211
  """
212
  if key_padding_mask is None or not self.use_flash_attn:
213
- mixer_kwargs = {'adapter_mask': adapter_mask}
214
  if key_padding_mask is not None:
215
- mixer_kwargs['key_padding_mask'] = key_padding_mask.bool()
216
  for layer in self.layers:
217
  if self._grad_checkpointing:
218
  hidden_states = torch.utils.checkpoint.checkpoint(
@@ -227,10 +219,14 @@ class XLMRobertaEncoder(nn.Module):
227
  hidden_states = hidden_states[subset_mask]
228
  else:
229
  batch, seqlen = hidden_states.shape[:2]
230
- hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = unpad_input(
231
- hidden_states, key_padding_mask, adapter_mask
232
  )
233
- mixer_kwargs = {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen_in_batch, "adapter_mask": cu_adapter_mask}
 
 
 
 
234
 
235
  if subset_mask is None:
236
  for layer in self.layers:
@@ -315,12 +311,18 @@ class XLMRobertaPooler(nn.Module):
315
  if adapter_mask is not None:
316
  unique_tasks = torch.unique(adapter_mask)
317
  pool_dtype = next(self.dense.parameters()).dtype
318
- pooled_output = torch.empty(first_token_tensor.shape[0], self.dense.out_features,
319
- dtype=pool_dtype, device=first_token_tensor.device)
 
 
 
 
320
  for task_id in unique_tasks:
321
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
322
  task_first_token_tensor = first_token_tensor[task_indices]
323
- task_pooled_output = self.dense(task_first_token_tensor, task_id=task_id)
 
 
324
  pooled_output[task_indices] = task_pooled_output
325
  else:
326
  pooled_output = self.dense(first_token_tensor)
@@ -413,12 +415,11 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
413
  *args,
414
  **kwargs,
415
  ):
416
- if not 'torch_dtype' in kwargs:
417
- kwargs['torch_dtype'] = 'auto'
418
  return super().from_pretrained(*args, **kwargs)
419
 
420
 
421
-
422
  class XLMRobertaModel(XLMRobertaPreTrainedModel):
423
  def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
424
  super().__init__(config)
@@ -439,7 +440,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
439
  self.embeddings = XLMRobertaEmbeddings(
440
  config.hidden_size,
441
  config.vocab_size,
442
- config.max_position_embeddings if config.position_embedding_type == 'absolute' else -1,
 
 
 
 
443
  config.type_vocab_size,
444
  padding_idx=config.pad_token_id,
445
  )
@@ -449,16 +454,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
449
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
450
 
451
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
452
- self.tokenizer = AutoTokenizer.from_pretrained(self.name_or_path, trust_remote_code=True)
 
 
453
  self._rotary_emb_base = config.rotary_emb_base
454
 
455
  @torch.inference_mode()
456
  def encode(
457
- self: 'XLMRobertaModel',
458
  sentences: Union[str, List[str]],
459
  batch_size: int = 32,
460
  show_progress_bar: Optional[bool] = None,
461
- output_value: str = 'sentence_embedding',
462
  convert_to_numpy: bool = True,
463
  convert_to_tensor: bool = False,
464
  device: Optional[torch.device] = None,
@@ -516,12 +523,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
516
  if convert_to_tensor:
517
  convert_to_numpy = False
518
 
519
- if output_value != 'sentence_embedding':
520
  convert_to_tensor = False
521
  convert_to_numpy = False
522
 
523
  input_was_string = False
524
- if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
525
  sentences = [sentences]
526
  input_was_string = True
527
 
@@ -532,11 +539,11 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
532
  inverse_permutation = np.argsort(permutation)
533
  sentences = [sentences[idx] for idx in permutation]
534
 
535
- tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
536
- tokenizer_kwargs['max_length'] = tokenizer_kwargs.get(
537
- 'max_length', self.tokenizer.init_kwargs.get('model_max_length', 8192)
538
  )
539
- tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
540
 
541
  all_embeddings = []
542
 
@@ -550,11 +557,13 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
550
  )
551
  else:
552
  range_iter = range(0, len(sentences), batch_size)
553
- lora_arguments = {'adapter_mask': adapter_mask} if adapter_mask is not None else {}
 
 
554
  for i in range_iter:
555
  encoded_input = self.tokenizer(
556
  sentences[i : i + batch_size],
557
- return_tensors='pt',
558
  **tokenizer_kwargs,
559
  ).to(self.device)
560
  token_embs = self.forward(**encoded_input, **lora_arguments)[0]
@@ -562,18 +571,18 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
562
  # Accumulate in fp32 to avoid overflow
563
  token_embs = token_embs.float()
564
 
565
- if output_value == 'token_embeddings':
566
  raise NotImplementedError
567
  elif output_value is None:
568
  raise NotImplementedError
569
  else:
570
- if self.config.emb_pooler == 'cls':
571
  embeddings = self.cls_pooling(
572
- token_embs, encoded_input['attention_mask']
573
  )
574
  else:
575
  embeddings = self.mean_pooling(
576
- token_embs, encoded_input['attention_mask']
577
  )
578
 
579
  if normalize_embeddings:
@@ -603,14 +612,16 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
603
  def truncate_embeddings(self, embeddings, truncate_dim):
604
  if not self.config.matryoshka_dimensions:
605
  logger.warning(
606
- 'Matryoshka embeddings are not supported, so dimension truncation will not be performed.'
607
  )
608
  return embeddings
609
  elif truncate_dim in self.config.matryoshka_dimensions:
610
  return [tensor[:truncate_dim] for tensor in embeddings]
611
  else:
612
- raise ValueError(f'The provided `truncate_dim` value of {truncate_dim} is not supported. '
613
- f'Supported dimensions are {self.config.matryoshka_dimensions}.')
 
 
614
 
615
  def mean_pooling(
616
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
@@ -622,10 +633,8 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
622
  input_mask_expanded.sum(1), min=1e-9
623
  )
624
 
625
- def cls_pooling(
626
- self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
627
- ):
628
- return token_embeddings[:,0]
629
 
630
  @property
631
  def rotary_emb_base(self):
@@ -635,7 +644,7 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
635
  def rotary_emb_base(self, base):
636
  if not isinstance(base, (int, float)):
637
  raise TypeError("Base must be an integer or float")
638
- logger.info(f'Changing RoPE base value to {base}')
639
  for layer in self.encoder.layers:
640
  layer.mixer.rotary_emb.base = base
641
  self._rotary_emb_base = base
@@ -655,12 +664,12 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
655
  layer output for these tokens.
656
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
657
  """
658
- adapter_mask = kwargs.pop('adapter_mask', None)
659
  if kwargs:
660
  for key, value in kwargs.items():
661
  if value is not None:
662
  logger.warning(
663
- 'Flash attention implementation does not support kwargs: %s',
664
  key,
665
  )
666
 
@@ -669,7 +678,10 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
669
  )
670
 
671
  hidden_states = self.embeddings(
672
- input_ids, position_ids=position_ids, token_type_ids=token_type_ids, adapter_mask=adapter_mask
 
 
 
673
  )
674
  # TD [2022-12:18]: Don't need to force residual in fp32
675
  # BERT puts embedding LayerNorm before embedding dropout.
@@ -693,12 +705,17 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
693
  subset_mask = None
694
 
695
  sequence_output = self.encoder(
696
- hidden_states, key_padding_mask=attention_mask, subset_mask=subset_mask, adapter_mask=adapter_mask
 
 
 
697
  )
698
 
699
  if masked_tokens_mask is None:
700
  pooled_output = (
701
- self.pooler(sequence_output, adapter_mask=adapter_mask) if self.pooler is not None else None
 
 
702
  )
703
  else:
704
  # TD [2022-03-01]: the indexing here is very tricky.
@@ -712,7 +729,9 @@ class XLMRobertaModel(XLMRobertaPreTrainedModel):
712
  pool_input = sequence_output[first_col_mask[subset_mask]]
713
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
714
  pooled_output = (
715
- self.pooler(pool_input, pool=False, adapter_mask=adapter_mask) if self.pooler is not None else None
 
 
716
  )
717
 
718
  if not return_dict:
@@ -817,103 +836,6 @@ class XLMRobertaForMaskedLM(XLMRobertaPreTrainedModel):
817
  )
818
 
819
 
820
- # class XLMRobertaForPreTraining(XLMRobertaPreTrainedModel):
821
- # def __init__(self, config: XLMRobertaFlashConfig):
822
- # super().__init__(config)
823
- # # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
824
- # # (around 15%) to the classifier heads.
825
- # self.dense_seq_output = getattr(config, "dense_seq_output", False)
826
- # # If last_layer_subset, we only need the compute the last layer for a subset of tokens
827
- # # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
828
- # self.last_layer_subset = getattr(config, "last_layer_subset", False)
829
- # if self.last_layer_subset:
830
- # assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
831
- # use_xentropy = getattr(config, "use_xentropy", False)
832
- # if use_xentropy and CrossEntropyLoss is None:
833
- # raise ImportError("xentropy_cuda is not installed")
834
- # loss_cls = (
835
- # nn.CrossEntropyLoss
836
- # if not use_xentropy
837
- # else partial(CrossEntropyLoss, inplace_backward=True)
838
- # )
839
- #
840
- # self.xlm = XLMRobertaModel(config)
841
- # self.cls = XLMRobertaPreTrainingHeads(config)
842
- # self.mlm_loss = loss_cls(ignore_index=0)
843
- # self.nsp_loss = loss_cls(ignore_index=-1)
844
- #
845
- # # Initialize weights and apply final processing
846
- # self.apply(partial(_init_weights, initializer_range=config.initializer_range))
847
- # self.tie_weights()
848
- #
849
- # def tie_weights(self):
850
- # self.cls.predictions.decoder.weight = self.xlm.embeddings.word_embeddings.weight
851
- #
852
- # def forward(
853
- # self,
854
- # input_ids,
855
- # position_ids=None,
856
- # token_type_ids=None,
857
- # attention_mask=None,
858
- # labels=None,
859
- # next_sentence_label=None,
860
- # ):
861
- # """
862
- # If labels are provided, they must be 0 for masked out tokens (as specified in the attention
863
- # mask).
864
- # Outputs:
865
- # if `labels` and `next_sentence_label` are not `None`:
866
- # Outputs the total_loss which is the sum of the masked language modeling loss and the next
867
- # sentence classification loss.
868
- # if `labels` or `next_sentence_label` is `None`:
869
- # Outputs a tuple comprising
870
- # - the masked language modeling logits of shape [batch_size, sequence_length, vocab_size], and
871
- # - the next sentence classification logits of shape [batch_size, 2].
872
- #
873
- # """
874
- # masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
875
- # outputs = self.xlm(
876
- # input_ids,
877
- # position_ids=position_ids,
878
- # token_type_ids=token_type_ids,
879
- # attention_mask=attention_mask.bool() if attention_mask is not None else None,
880
- # masked_tokens_mask=masked_tokens_mask,
881
- # )
882
- # sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
883
- # if self.dense_seq_output and labels is not None:
884
- # masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
885
- # if not self.last_layer_subset:
886
- # sequence_output = index_first_axis(
887
- # rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
888
- # )
889
- # prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
890
- #
891
- # total_loss = None
892
- # if labels is not None and next_sentence_label is not None:
893
- # if (
894
- # self.dense_seq_output and labels is not None
895
- # ): # prediction_scores are already flattened
896
- # masked_lm_loss = self.mlm_loss(
897
- # prediction_scores, labels.flatten()[masked_token_idx]
898
- # )
899
- # else:
900
- # masked_lm_loss = self.mlm_loss(
901
- # rearrange(prediction_scores, "... v -> (...) v"),
902
- # rearrange(labels, "... -> (...)"),
903
- # )
904
- # next_sentence_loss = self.nsp_loss(
905
- # rearrange(seq_relationship_score, "... t -> (...) t"),
906
- # rearrange(next_sentence_label, "... -> (...)"),
907
- # )
908
- # total_loss = masked_lm_loss.float() + next_sentence_loss.float()
909
- #
910
- # return BertForPreTrainingOutput(
911
- # loss=total_loss,
912
- # prediction_logits=prediction_scores,
913
- # seq_relationship_logits=seq_relationship_score,
914
- # )
915
-
916
-
917
  def remap_state_dict(state_dict, config: PretrainedConfig):
918
  """
919
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
@@ -1065,47 +987,47 @@ def inv_remap_state_dict(state_dict, config: PretrainedConfig):
1065
  if not last_layer_subset or d != (config.num_hidden_layers - 1):
1066
  Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
1067
  Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
1068
- state_dict[
1069
- f"bert.encoder.layers.{d}.attention.self.query.weight"
1070
- ] = Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
1071
- state_dict[
1072
- f"bert.encoder.layers.{d}.attention.self.key.weight"
1073
- ] = Wqkv_weights[
1074
- Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
1075
- ]
1076
- state_dict[
1077
- f"bert.encoder.layers.{d}.attention.self.value.weight"
1078
- ] = Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1079
- state_dict[
1080
- f"bert.encoder.layers.{d}.attention.self.query.bias"
1081
- ] = Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1082
- state_dict[
1083
- f"bert.encoder.layers.{d}.attention.self.key.bias"
1084
- ] = Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1085
- state_dict[
1086
- f"bert.encoder.layers.{d}.attention.self.value.bias"
1087
- ] = Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1088
  else:
1089
  Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1090
  Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1091
  Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1092
  Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1093
- state_dict[
1094
- f"bert.encoder.layers.{d}.attention.self.query.weight"
1095
- ] = Wq_weight
1096
- state_dict[
1097
- f"bert.encoder.layers.{d}.attention.self.key.weight"
1098
- ] = Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1099
- state_dict[
1100
- f"bert.encoder.layers.{d}.attention.self.value.weight"
1101
- ] = Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1102
  state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1103
  state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1104
  : Wkv_biases.shape[0] // 2
1105
  ]
1106
- state_dict[
1107
- f"bert.encoder.layers.{d}.attention.self.value.bias"
1108
- ] = Wkv_biases[Wkv_biases.shape[0] // 2 :]
1109
 
1110
  def inv_key_mapping_ln(key):
1111
  key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
@@ -1294,4 +1216,4 @@ class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
1294
  logits=logits,
1295
  hidden_states=outputs.hidden_states,
1296
  attentions=outputs.attentions,
1297
- )
 
13
  from collections import OrderedDict
14
  from collections.abc import Sequence
15
  from functools import partial
16
+ from typing import List, Optional, Tuple, Union
17
 
18
+ import numpy as np
19
  import torch
20
  import torch.nn as nn
21
  import torch.nn.functional as F
22
  import torch.utils.checkpoint
23
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
24
+ from transformers import AutoTokenizer, PretrainedConfig
25
+ from transformers.modeling_outputs import (MaskedLMOutput,
26
+ SequenceClassifierOutput)
27
  from transformers.modeling_utils import PreTrainedModel
 
 
 
28
  from transformers.models.bert.modeling_bert import (
29
+ BaseModelOutputWithPoolingAndCrossAttentions, BertForPreTrainingOutput)
30
+ from transformers.models.xlm_roberta.modeling_xlm_roberta import \
31
+ XLMRobertaLMHead
32
 
 
 
 
 
 
 
 
 
 
33
  from .block import Block
34
+ from .configuration_xlm_roberta import XLMRobertaFlashConfig
35
  from .embedding import XLMRobertaEmbeddings
36
  from .mha import MHA
37
  from .mlp import FusedMLP, Mlp
38
+ from .xlm_padding import index_first_axis_residual, pad_input, unpad_input
 
39
 
40
  try:
41
  from flash_attn.ops.fused_dense import FusedDense
 
69
  return False
70
  if importlib.util.find_spec("flash_attn") is None:
71
  logger.warning(
72
+ "flash_attn is not installed. Using PyTorch native attention implementation."
73
  )
74
  return False
75
  return True
 
99
  fused_bias_fc=fused_bias_fc,
100
  use_flash_attn=use_flash_attn,
101
  return_residual=return_residual,
102
+ use_alibi=config.position_embedding_type == "alibi",
103
  **rotary_kwargs,
104
  )
105
  return mixer_cls
 
194
  def gradient_checkpointing(self, value):
195
  self._grad_checkpointing = value
196
 
197
+ def forward(
198
+ self, hidden_states, key_padding_mask=None, subset_mask=None, adapter_mask=None
199
+ ):
200
  """If subset_mask is not None, we only want output for the subset of the sequence.
201
  This means that we only compute the last layer output for these tokens.
202
  subset_mask: (batch, seqlen), dtype=torch.bool
203
  """
204
  if key_padding_mask is None or not self.use_flash_attn:
205
+ mixer_kwargs = {"adapter_mask": adapter_mask}
206
  if key_padding_mask is not None:
207
+ mixer_kwargs["key_padding_mask"] = key_padding_mask.bool()
208
  for layer in self.layers:
209
  if self._grad_checkpointing:
210
  hidden_states = torch.utils.checkpoint.checkpoint(
 
219
  hidden_states = hidden_states[subset_mask]
220
  else:
221
  batch, seqlen = hidden_states.shape[:2]
222
+ hidden_states, indices, cu_seqlens, max_seqlen_in_batch, cu_adapter_mask = (
223
+ unpad_input(hidden_states, key_padding_mask, adapter_mask)
224
  )
225
+ mixer_kwargs = {
226
+ "cu_seqlens": cu_seqlens,
227
+ "max_seqlen": max_seqlen_in_batch,
228
+ "adapter_mask": cu_adapter_mask,
229
+ }
230
 
231
  if subset_mask is None:
232
  for layer in self.layers:
 
311
  if adapter_mask is not None:
312
  unique_tasks = torch.unique(adapter_mask)
313
  pool_dtype = next(self.dense.parameters()).dtype
314
+ pooled_output = torch.empty(
315
+ first_token_tensor.shape[0],
316
+ self.dense.out_features,
317
+ dtype=pool_dtype,
318
+ device=first_token_tensor.device,
319
+ )
320
  for task_id in unique_tasks:
321
  task_indices = (adapter_mask == task_id).nonzero(as_tuple=True)[0]
322
  task_first_token_tensor = first_token_tensor[task_indices]
323
+ task_pooled_output = self.dense(
324
+ task_first_token_tensor, task_id=task_id
325
+ )
326
  pooled_output[task_indices] = task_pooled_output
327
  else:
328
  pooled_output = self.dense(first_token_tensor)
 
415
  *args,
416
  **kwargs,
417
  ):
418
+ if not "torch_dtype" in kwargs:
419
+ kwargs["torch_dtype"] = "auto"
420
  return super().from_pretrained(*args, **kwargs)
421
 
422
 
 
423
  class XLMRobertaModel(XLMRobertaPreTrainedModel):
424
  def __init__(self, config: XLMRobertaFlashConfig, add_pooling_layer=True):
425
  super().__init__(config)
 
440
  self.embeddings = XLMRobertaEmbeddings(
441
  config.hidden_size,
442
  config.vocab_size,
443
+ (
444
+ config.max_position_embeddings
445
+ if config.position_embedding_type == "absolute"
446
+ else -1
447
+ ),
448
  config.type_vocab_size,
449
  padding_idx=config.pad_token_id,
450
  )
 
454
  self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
455
 
456
  self.apply(partial(_init_weights, initializer_range=config.initializer_range))
457
+ self.tokenizer = AutoTokenizer.from_pretrained(
458
+ self.name_or_path, trust_remote_code=True
459
+ )
460
  self._rotary_emb_base = config.rotary_emb_base
461
 
462
  @torch.inference_mode()
463
  def encode(
464
+ self: "XLMRobertaModel",
465
  sentences: Union[str, List[str]],
466
  batch_size: int = 32,
467
  show_progress_bar: Optional[bool] = None,
468
+ output_value: str = "sentence_embedding",
469
  convert_to_numpy: bool = True,
470
  convert_to_tensor: bool = False,
471
  device: Optional[torch.device] = None,
 
523
  if convert_to_tensor:
524
  convert_to_numpy = False
525
 
526
+ if output_value != "sentence_embedding":
527
  convert_to_tensor = False
528
  convert_to_numpy = False
529
 
530
  input_was_string = False
531
+ if isinstance(sentences, str) or not hasattr(sentences, "__len__"):
532
  sentences = [sentences]
533
  input_was_string = True
534
 
 
539
  inverse_permutation = np.argsort(permutation)
540
  sentences = [sentences[idx] for idx in permutation]
541
 
542
+ tokenizer_kwargs["padding"] = tokenizer_kwargs.get("padding", True)
543
+ tokenizer_kwargs["max_length"] = tokenizer_kwargs.get(
544
+ "max_length", self.tokenizer.init_kwargs.get("model_max_length", 8192)
545
  )
546
+ tokenizer_kwargs["truncation"] = tokenizer_kwargs.get("truncation", True)
547
 
548
  all_embeddings = []
549
 
 
557
  )
558
  else:
559
  range_iter = range(0, len(sentences), batch_size)
560
+ lora_arguments = (
561
+ {"adapter_mask": adapter_mask} if adapter_mask is not None else {}
562
+ )
563
  for i in range_iter:
564
  encoded_input = self.tokenizer(
565
  sentences[i : i + batch_size],
566
+ return_tensors="pt",
567
  **tokenizer_kwargs,
568
  ).to(self.device)
569
  token_embs = self.forward(**encoded_input, **lora_arguments)[0]
 
571
  # Accumulate in fp32 to avoid overflow
572
  token_embs = token_embs.float()
573
 
574
+ if output_value == "token_embeddings":
575
  raise NotImplementedError
576
  elif output_value is None:
577
  raise NotImplementedError
578
  else:
579
+ if self.config.emb_pooler == "cls":
580
  embeddings = self.cls_pooling(
581
+ token_embs, encoded_input["attention_mask"]
582
  )
583
  else:
584
  embeddings = self.mean_pooling(
585
+ token_embs, encoded_input["attention_mask"]
586
  )
587
 
588
  if normalize_embeddings:
 
612
  def truncate_embeddings(self, embeddings, truncate_dim):
613
  if not self.config.matryoshka_dimensions:
614
  logger.warning(
615
+ "Matryoshka embeddings are not supported, so dimension truncation will not be performed."
616
  )
617
  return embeddings
618
  elif truncate_dim in self.config.matryoshka_dimensions:
619
  return [tensor[:truncate_dim] for tensor in embeddings]
620
  else:
621
+ raise ValueError(
622
+ f"The provided `truncate_dim` value of {truncate_dim} is not supported. "
623
+ f"Supported dimensions are {self.config.matryoshka_dimensions}."
624
+ )
625
 
626
  def mean_pooling(
627
  self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
 
633
  input_mask_expanded.sum(1), min=1e-9
634
  )
635
 
636
+ def cls_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
637
+ return token_embeddings[:, 0]
 
 
638
 
639
  @property
640
  def rotary_emb_base(self):
 
644
  def rotary_emb_base(self, base):
645
  if not isinstance(base, (int, float)):
646
  raise TypeError("Base must be an integer or float")
647
+ logger.info(f"Changing RoPE base value to {base}")
648
  for layer in self.encoder.layers:
649
  layer.mixer.rotary_emb.base = base
650
  self._rotary_emb_base = base
 
664
  layer output for these tokens.
665
  masked_tokens_mask: (batch, seqlen), dtype=torch.bool
666
  """
667
+ adapter_mask = kwargs.pop("adapter_mask", None)
668
  if kwargs:
669
  for key, value in kwargs.items():
670
  if value is not None:
671
  logger.warning(
672
+ "Flash attention implementation does not support kwargs: %s",
673
  key,
674
  )
675
 
 
678
  )
679
 
680
  hidden_states = self.embeddings(
681
+ input_ids,
682
+ position_ids=position_ids,
683
+ token_type_ids=token_type_ids,
684
+ adapter_mask=adapter_mask,
685
  )
686
  # TD [2022-12:18]: Don't need to force residual in fp32
687
  # BERT puts embedding LayerNorm before embedding dropout.
 
705
  subset_mask = None
706
 
707
  sequence_output = self.encoder(
708
+ hidden_states,
709
+ key_padding_mask=attention_mask,
710
+ subset_mask=subset_mask,
711
+ adapter_mask=adapter_mask,
712
  )
713
 
714
  if masked_tokens_mask is None:
715
  pooled_output = (
716
+ self.pooler(sequence_output, adapter_mask=adapter_mask)
717
+ if self.pooler is not None
718
+ else None
719
  )
720
  else:
721
  # TD [2022-03-01]: the indexing here is very tricky.
 
729
  pool_input = sequence_output[first_col_mask[subset_mask]]
730
  sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
731
  pooled_output = (
732
+ self.pooler(pool_input, pool=False, adapter_mask=adapter_mask)
733
+ if self.pooler is not None
734
+ else None
735
  )
736
 
737
  if not return_dict:
 
836
  )
837
 
838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
839
  def remap_state_dict(state_dict, config: PretrainedConfig):
840
  """
841
  Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
 
987
  if not last_layer_subset or d != (config.num_hidden_layers - 1):
988
  Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
989
  Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
990
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
991
+ Wqkv_weights[: Wqkv_weights.shape[0] // 3, :]
992
+ )
993
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
994
+ Wqkv_weights[
995
+ Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
996
+ ]
997
+ )
998
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
999
+ Wqkv_weights[2 * Wqkv_weights.shape[0] // 3 :, :]
1000
+ )
1001
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = (
1002
+ Wqkv_biases[: Wqkv_biases.shape[0] // 3]
1003
+ )
1004
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = (
1005
+ Wqkv_biases[Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3]
1006
+ )
1007
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1008
+ Wqkv_biases[2 * Wqkv_biases.shape[0] // 3 :]
1009
+ )
1010
  else:
1011
  Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
1012
  Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
1013
  Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
1014
  Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
1015
+ state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = (
1016
+ Wq_weight
1017
+ )
1018
+ state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = (
1019
+ Wkv_weights[: Wkv_weights.shape[0] // 2, :]
1020
+ )
1021
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = (
1022
+ Wkv_weights[Wkv_weights.shape[0] // 2 :, :]
1023
+ )
1024
  state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
1025
  state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
1026
  : Wkv_biases.shape[0] // 2
1027
  ]
1028
+ state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = (
1029
+ Wkv_biases[Wkv_biases.shape[0] // 2 :]
1030
+ )
1031
 
1032
  def inv_key_mapping_ln(key):
1033
  key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
 
1216
  logits=logits,
1217
  hidden_states=outputs.hidden_states,
1218
  attentions=outputs.attentions,
1219
+ )
modeling_xlm_roberta_for_glue.py DELETED
@@ -1,109 +0,0 @@
1
- from typing import Optional, Union, Tuple
2
-
3
- import torch
4
- from torch import nn
5
- from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
6
- from transformers.modeling_outputs import SequenceClassifierOutput, QuestionAnsweringModelOutput, TokenClassifierOutput
7
-
8
- from .modeling_xlm_roberta import XLMRobertaPreTrainedModel, XLMRobertaModel
9
- from .configuration_xlm_roberta import XLMRobertaFlashConfig
10
-
11
-
12
- class XLMRobertaForSequenceClassification(XLMRobertaPreTrainedModel):
13
- def __init__(self, config: XLMRobertaFlashConfig):
14
- super().__init__(config)
15
- self.num_labels = config.num_labels
16
- self.config = config
17
-
18
- self.roberta = XLMRobertaModel(config)
19
- classifier_dropout = (
20
- config.classifier_dropout
21
- if config.classifier_dropout is not None
22
- else config.hidden_dropout_prob
23
- )
24
- self.dropout = nn.Dropout(classifier_dropout)
25
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
26
-
27
- # Initialize weights and apply final processing
28
- self.post_init()
29
-
30
-
31
- def forward(
32
- self,
33
- input_ids: Optional[torch.Tensor] = None,
34
- attention_mask: Optional[torch.Tensor] = None,
35
- token_type_ids: Optional[torch.Tensor] = None,
36
- position_ids: Optional[torch.Tensor] = None,
37
- head_mask: Optional[torch.Tensor] = None,
38
- inputs_embeds: Optional[torch.Tensor] = None,
39
- labels: Optional[torch.Tensor] = None,
40
- output_attentions: Optional[bool] = None,
41
- output_hidden_states: Optional[bool] = None,
42
- return_dict: Optional[bool] = None,
43
- ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
44
- r"""
45
- labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
46
- Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
47
- config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
48
- `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
49
- """
50
- return_dict = (
51
- return_dict if return_dict is not None else self.config.use_return_dict
52
- )
53
-
54
- assert head_mask is None
55
- assert inputs_embeds is None
56
- assert output_attentions is None
57
- assert output_hidden_states is None
58
- assert return_dict
59
- outputs = self.roberta(
60
- input_ids,
61
- attention_mask=attention_mask,
62
- token_type_ids=token_type_ids,
63
- position_ids=position_ids,
64
- head_mask=head_mask,
65
- inputs_embeds=inputs_embeds,
66
- output_attentions=output_attentions,
67
- output_hidden_states=output_hidden_states,
68
- return_dict=return_dict,
69
- )
70
-
71
- pooled_output = outputs[1]
72
-
73
- pooled_output = self.dropout(pooled_output)
74
- logits = self.classifier(pooled_output)
75
-
76
- loss = None
77
- if labels is not None:
78
- if self.config.problem_type is None:
79
- if self.num_labels == 1:
80
- self.config.problem_type = "regression"
81
- elif self.num_labels > 1 and (
82
- labels.dtype == torch.long or labels.dtype == torch.int
83
- ):
84
- self.config.problem_type = "single_label_classification"
85
- else:
86
- self.config.problem_type = "multi_label_classification"
87
-
88
- if self.config.problem_type == "regression":
89
- loss_fct = MSELoss()
90
- if self.num_labels == 1:
91
- loss = loss_fct(logits.squeeze(), labels.squeeze())
92
- else:
93
- loss = loss_fct(logits, labels)
94
- elif self.config.problem_type == "single_label_classification":
95
- loss_fct = CrossEntropyLoss()
96
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
97
- elif self.config.problem_type == "multi_label_classification":
98
- loss_fct = BCEWithLogitsLoss()
99
- loss = loss_fct(logits, labels)
100
- if not return_dict:
101
- output = (logits,) + outputs[2:]
102
- return ((loss,) + output) if loss is not None else output
103
-
104
- return SequenceClassifierOutput(
105
- loss=loss,
106
- logits=logits,
107
- hidden_states=outputs.hidden_states,
108
- attentions=outputs.attentions,
109
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
rotary.py CHANGED
@@ -1,4 +1,7 @@
1
- # Adapted from https://github.com/Dao-AILab/flash-attention/pull/556
 
 
 
2
  # Copyright (c) 2023, Tri Dao.
3
 
4
  import math
@@ -11,8 +14,9 @@ if torch.cuda.is_available():
11
  try:
12
  from flash_attn.ops.triton.rotary import apply_rotary
13
  except ImportError:
 
14
  def apply_rotary(*args, **kwargs):
15
- raise RuntimeError('RoPE requires flash-attention to be installed')
16
 
17
 
18
  def rotate_half(x, interleaved=False):
@@ -21,7 +25,9 @@ def rotate_half(x, interleaved=False):
21
  return torch.cat((-x2, x1), dim=-1)
22
  else:
23
  x1, x2 = x[..., ::2], x[..., 1::2]
24
- return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
 
 
25
 
26
 
27
  def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
@@ -32,13 +38,20 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
32
  ro_dim = cos.shape[-1] * 2
33
  assert ro_dim <= x.shape[-1]
34
  cos, sin = (
35
- cos[:x.shape[1]],
36
- sin[:x.shape[1]],
 
 
 
 
 
 
37
  )
38
- cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
39
- sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
40
  return torch.cat(
41
- [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
 
 
 
42
  dim=-1,
43
  )
44
 
@@ -68,7 +81,9 @@ class ApplyRotaryEmb(torch.autograd.Function):
68
  )
69
 
70
  if isinstance(seqlen_offsets, int):
71
- ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
 
 
72
  ctx.seqlen_offsets = seqlen_offsets
73
  else:
74
  ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
@@ -336,7 +351,9 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
336
  max_seqlen=max_seqlen,
337
  )
338
  if isinstance(seqlen_offsets, int):
339
- ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward
 
 
340
  ctx.seqlen_offsets = seqlen_offsets
341
  else:
342
  ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
@@ -451,7 +468,8 @@ class RotaryEmbedding(torch.nn.Module):
451
  self.interleaved = interleaved
452
  self.scale_base = scale_base
453
  scale = (
454
- (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
 
455
  if scale_base is not None
456
  else None
457
  )
@@ -477,7 +495,10 @@ class RotaryEmbedding(torch.nn.Module):
477
  def _compute_inv_freq(self, device=None):
478
  return 1.0 / (
479
  self.base
480
- ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
 
 
 
481
  )
482
 
483
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
@@ -516,10 +537,14 @@ class RotaryEmbedding(torch.nn.Module):
516
  self._sin_cached = torch.sin(freqs).to(dtype)
517
  else:
518
  power = (
519
- torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
 
 
520
  - seqlen // 2
521
  ) / self.scale_base
522
- scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
 
 
523
  # We want the multiplication by scale to happen in fp32
524
  self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
525
  self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
@@ -550,7 +575,9 @@ class RotaryEmbedding(torch.nn.Module):
550
  if max_seqlen is not None:
551
  self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
552
  elif isinstance(seqlen_offset, int):
553
- self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
 
 
554
  if kv is None:
555
  if self.scale is None:
556
  return apply_rotary_emb_qkv_(
@@ -606,4 +633,4 @@ class RotaryEmbedding(torch.nn.Module):
606
  cu_seqlens=cu_seqlens,
607
  max_seqlen=max_seqlen,
608
  )
609
- return q, kv
 
1
+ # This implementation was adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py
2
+ # Commit id: 3566596ad867ee415dd3c12616dd50c610176f6c
3
+ # Rotary varlen support from https://github.com/Dao-AILab/flash-attention/pull/556
4
+
5
  # Copyright (c) 2023, Tri Dao.
6
 
7
  import math
 
14
  try:
15
  from flash_attn.ops.triton.rotary import apply_rotary
16
  except ImportError:
17
+
18
  def apply_rotary(*args, **kwargs):
19
+ raise RuntimeError("RoPE requires flash-attention to be installed")
20
 
21
 
22
  def rotate_half(x, interleaved=False):
 
25
  return torch.cat((-x2, x1), dim=-1)
26
  else:
27
  x1, x2 = x[..., ::2], x[..., 1::2]
28
+ return rearrange(
29
+ torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
30
+ )
31
 
32
 
33
  def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
 
38
  ro_dim = cos.shape[-1] * 2
39
  assert ro_dim <= x.shape[-1]
40
  cos, sin = (
41
+ cos[: x.shape[1]],
42
+ sin[: x.shape[1]],
43
+ )
44
+ cos = repeat(
45
+ cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
46
+ )
47
+ sin = repeat(
48
+ sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
49
  )
 
 
50
  return torch.cat(
51
+ [
52
+ x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
53
+ x[..., ro_dim:],
54
+ ],
55
  dim=-1,
56
  )
57
 
 
81
  )
82
 
83
  if isinstance(seqlen_offsets, int):
84
+ ctx.save_for_backward(
85
+ cos, sin, cu_seqlens
86
+ ) # Can't save int with save_for_backward
87
  ctx.seqlen_offsets = seqlen_offsets
88
  else:
89
  ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
 
351
  max_seqlen=max_seqlen,
352
  )
353
  if isinstance(seqlen_offsets, int):
354
+ ctx.save_for_backward(
355
+ cos, sin, cu_seqlens
356
+ ) # Can't save int with save_for_backward
357
  ctx.seqlen_offsets = seqlen_offsets
358
  else:
359
  ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
 
468
  self.interleaved = interleaved
469
  self.scale_base = scale_base
470
  scale = (
471
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
472
+ / (1.4 * dim)
473
  if scale_base is not None
474
  else None
475
  )
 
495
  def _compute_inv_freq(self, device=None):
496
  return 1.0 / (
497
  self.base
498
+ ** (
499
+ torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
500
+ / self.dim
501
+ )
502
  )
503
 
504
  def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
 
537
  self._sin_cached = torch.sin(freqs).to(dtype)
538
  else:
539
  power = (
540
+ torch.arange(
541
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
542
+ )
543
  - seqlen // 2
544
  ) / self.scale_base
545
+ scale = self.scale.to(device=power.device) ** rearrange(
546
+ power, "s -> s 1"
547
+ )
548
  # We want the multiplication by scale to happen in fp32
549
  self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
550
  self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
 
575
  if max_seqlen is not None:
576
  self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
577
  elif isinstance(seqlen_offset, int):
578
+ self._update_cos_sin_cache(
579
+ seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype
580
+ )
581
  if kv is None:
582
  if self.scale is None:
583
  return apply_rotary_emb_qkv_(
 
633
  cu_seqlens=cu_seqlens,
634
  max_seqlen=max_seqlen,
635
  )
636
+ return q, kv
stochastic_depth.py CHANGED
@@ -34,7 +34,7 @@
34
 
35
  import torch
36
  import torch.fx
37
- from torch import nn, Tensor
38
 
39
 
40
  def stochastic_depth(
 
34
 
35
  import torch
36
  import torch.fx
37
+ from torch import Tensor, nn
38
 
39
 
40
  def stochastic_depth(
xlm_padding.py CHANGED
@@ -18,7 +18,9 @@ class IndexFirstAxis(torch.autograd.Function):
18
  # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
  # return input[indices]
20
  return torch.gather(
21
- rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
 
 
22
  ).reshape(-1, *other_shape)
23
 
24
  @staticmethod
@@ -34,7 +36,9 @@ class IndexFirstAxis(torch.autograd.Function):
34
  )
35
  # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
36
  # grad_input[indices] = grad_output
37
- grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
 
 
38
  return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
39
 
40
 
@@ -112,9 +116,15 @@ def unpad_input(hidden_states, attention_mask, adapter_mask=None):
112
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
113
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
114
  max_seqlen_in_batch = seqlens_in_batch.max().item()
115
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
116
 
117
- cu_adapter_mask = torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1]) if adapter_mask is not None else None
 
 
 
 
118
 
119
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
120
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
@@ -184,14 +194,18 @@ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_leng
184
  """
185
  length = attention_mask_in_length.sum(dim=-1)
186
  seqlen = attention_mask_in_length.size(-1)
187
- attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length),
188
- seqlen) < length.unsqueeze(
189
- 1)
190
- real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
 
 
191
  seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
192
  indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
193
  max_seqlen_in_batch = seqlens_in_batch.max().item()
194
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
 
 
195
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
196
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
197
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
@@ -219,4 +233,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
219
  # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
220
  # output[indices] = hidden_states
221
  output = index_put_first_axis(hidden_states, indices, batch * seqlen)
222
- return rearrange(output, "(b s) ... -> b s ...", b=batch)
 
18
  # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
19
  # return input[indices]
20
  return torch.gather(
21
+ rearrange(input, "b ... -> b (...)"),
22
+ 0,
23
+ repeat(indices, "z -> z d", d=second_dim),
24
  ).reshape(-1, *other_shape)
25
 
26
  @staticmethod
 
36
  )
37
  # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
38
  # grad_input[indices] = grad_output
39
+ grad_input.scatter_(
40
+ 0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
41
+ )
42
  return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
43
 
44
 
 
116
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
117
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
118
  max_seqlen_in_batch = seqlens_in_batch.max().item()
119
+ cu_seqlens = F.pad(
120
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
121
+ )
122
 
123
+ cu_adapter_mask = (
124
+ torch.repeat_interleave(adapter_mask, cu_seqlens[1:] - cu_seqlens[:-1])
125
+ if adapter_mask is not None
126
+ else None
127
+ )
128
 
129
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
130
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
 
194
  """
195
  length = attention_mask_in_length.sum(dim=-1)
196
  seqlen = attention_mask_in_length.size(-1)
197
+ attention_mask_2d = torch.arange(
198
+ seqlen, device=length.device, dtype=length.dtype
199
+ ).expand(len(length), seqlen) < length.unsqueeze(1)
200
+ real_indices_idx = torch.nonzero(
201
+ attention_mask_in_length.flatten(), as_tuple=False
202
+ ).flatten()
203
  seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
204
  indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
205
  max_seqlen_in_batch = seqlens_in_batch.max().item()
206
+ cu_seqlens = F.pad(
207
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
208
+ )
209
  # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
210
  # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
211
  # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
 
233
  # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
234
  # output[indices] = hidden_states
235
  output = index_put_first_axis(hidden_states, indices, batch * seqlen)
236
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)