ccdv committed
Commit 40317f8 · 1 Parent(s): 728b7d8

bos_token + readme

Files changed (2):
  1. README.md +12 -7
  2. modeling_lsg_albert.py +40 -13
README.md CHANGED
@@ -69,26 +69,31 @@ model = AutoModel.from_pretrained("ccdv/lsg-albert-base-v2-4096",
 
 ## Sparse selection type
 
-There are 5 different sparse selection patterns. The best type is task dependent. \
+There are 6 different sparse selection patterns. The best type is task dependent. \
+If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
 Note that for sequences with length < 2*block_size, the type has no effect.
-
-* sparsity_type="norm", select highest norm tokens
+* `sparsity_type="bos_pooling"` (new)
+    * weighted average pooling using the BOS token
+    * Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
+    * Additional parameters:
+        * None
+* `sparsity_type="norm"`, select highest norm tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
         * None
-* sparsity_type="pooling", use average pooling to merge tokens
+* `sparsity_type="pooling"`, use average pooling to merge tokens
     * Works best for a small sparsity_factor (2 to 4)
     * Additional parameters:
         * None
-* sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
+* `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
     * Works best for a large sparsity_factor (4+)
     * LSH relies on random projections, thus inference may differ slightly with different seeds
     * Additional parameters:
        * lsg_num_pre_rounds=1, pre-merge tokens n times before computing centroids
-* sparsity_type="stride", use a striding mechanism per head
+* `sparsity_type="stride"`, use a striding mechanism per head
    * Each head will use different tokens strided by sparsity_factor
    * Not recommended if sparsity_factor > num_heads
-* sparsity_type="block_stride", use a striding mechanism per head
+* `sparsity_type="block_stride"`, use a striding mechanism per head
    * Each head will use blocks of tokens strided by sparsity_factor
    * Not recommended if sparsity_factor > num_heads
 
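As a usage note (not part of the diff): the hunk context above references the `AutoModel.from_pretrained("ccdv/lsg-albert-base-v2-4096", ...)` snippet from the README, so selecting the new pattern could look like the sketch below. This is a hedged example: the exact set of keyword overrides accepted by `from_pretrained` is an assumption based on the config parameters shown in this commit, and `trust_remote_code=True` is assumed because the model ships its own `modeling_lsg_albert.py`.

```python
from transformers import AutoModel, AutoTokenizer

# Hedged sketch: sparsity_type / sparsity_factor / sparse_block_size are the LSG
# config fields touched by this commit; other load-time details may differ.
model = AutoModel.from_pretrained(
    "ccdv/lsg-albert-base-v2-4096",
    trust_remote_code=True,        # custom modeling_lsg_albert.py
    sparsity_type="bos_pooling",   # new selection pattern added by this commit
    sparsity_factor=8,             # README suggests larger factors (8, 16, 32) for bos_pooling
    sparse_block_size=64,
)
tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-albert-base-v2-4096")
```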
modeling_lsg_albert.py CHANGED
@@ -53,16 +53,16 @@ class LSGAlbertConfig(AlbertConfig):
         self.sparsity_factor = sparsity_factor
         self.sparsity_type = sparsity_type
 
-        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
+        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
                     setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
-            if self.sparsity_factor > self.encoder_attention_heads:
+            if self.sparsity_factor > self.num_attention_heads:
                 logger.warning(
-                    "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
+                    "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
                 )
 
         if self.num_global_tokens < 1:
@@ -463,7 +463,7 @@ class LSGAlbertEmbeddings(AlbertEmbeddings):
         return embeddings
 
 
-class LSGAttention(BaseSelfAttention):
+class LSGSelfAttention(BaseSelfAttention):
     '''
     Compute local attention with overlapping blocs
     Use global attention for tokens with highest norm
@@ -502,15 +502,16 @@ class LSGAttention(BaseSelfAttention):
             "lsh": self.get_sparse_tokens_with_lsh,
             "stride": self.get_sparse_tokens_with_stride,
             "block_stride": self.get_sparse_tokens_with_block_stride,
+            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
         }
 
         self.sparsity_type = config.sparsity_type
-        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
+        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
 
         if config.sparsity_type == "lsh":
             self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
 
-    def get_sparse_tokens_with_norm(self, keys, values, mask):
+    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -538,7 +539,7 @@ class LSGAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_pooling(self, keys, values, mask):
+    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -561,7 +562,7 @@ class LSGAttention(BaseSelfAttention):
         mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
 
-    def get_sparse_tokens_with_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -577,7 +578,7 @@ class LSGAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_block_stride(self, keys, values, mask):
+    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
@@ -597,11 +598,14 @@ class LSGAttention(BaseSelfAttention):
 
         return keys, values, mask
 
-    def get_sparse_tokens_with_lsh(self, keys, values, mask):
+    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
 
         if self.sparsity_factor == 1:
             return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
 
+        if self.sparsity_factor == self.sparse_block_size:
+            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
+
         block_size = min(self.block_size, self.sparse_block_size)
         keys = self.chunk(keys, block_size)
         values = self.chunk(values, block_size)
@@ -649,6 +653,29 @@ class LSGAttention(BaseSelfAttention):
 
         return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
 
+    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
+
+        if self.sparsity_factor == 1:
+            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
+
+        queries = queries.unsqueeze(-3)
+        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
+        keys = self.chunk(keys, self.sparsity_factor)
+        values = self.chunk(values, self.sparsity_factor)
+
+        n, h, b, t, d = keys.size()
+        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
+        if mask is not None:
+            scores = scores + mask
+
+        scores = torch.softmax(scores, dim=-1)
+        keys = scores @ keys
+        values = scores @ values
+        mask = mask.mean(dim=-1)
+        mask[mask != torch.finfo(mask.dtype).min] = 0
+
+        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
+
     def forward(
         self,
         hidden_states,
@@ -720,7 +747,7 @@ class LSGAttention(BaseSelfAttention):
         sparse_key, sparse_value, sparse_mask = (None, None, None)
 
         if self.sparse_block_size and self.sparsity_factor > 0:
-            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
+            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
 
         # Expand masks on heads
         attention_mask = attention_mask.expand(-1, h, -1, -1)
@@ -757,7 +784,7 @@ class LSGAlbertLayer(AlbertLayer):
     def __init__(self, config):
         super().__init__(config)
 
-        self.attention = LSGAttention(config)
+        self.attention = LSGSelfAttention(config)
 
 
 class LSGAlbertLayerGroup(AlbertLayerGroup):
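
To make the new `get_sparse_tokens_with_bos_pooling` concrete: within each block of `sparsity_factor` tokens, the BOS query scores the block's keys, the scores are softmaxed, and the block is collapsed to a single attention-weighted key/value. Below is a minimal single-head sketch of that idea on plain tensors; `bos_weighted_pooling` and its shapes are illustrative only and ignore the head dimension and the attention-mask handling of the real method.

```python
import math
import torch

def bos_weighted_pooling(bos_query, keys, values, sparsity_factor):
    """Compress keys/values by `sparsity_factor` using attention weights of the BOS query.

    Toy, single-head shapes: bos_query (n, 1, d), keys/values (n, t, d),
    with t divisible by sparsity_factor.
    """
    n, t, d = keys.shape
    b = t // sparsity_factor
    # Split the sequence into blocks of `sparsity_factor` tokens.
    keys = keys.reshape(n, b, sparsity_factor, d)
    values = values.reshape(n, b, sparsity_factor, d)

    # Score every token in a block against the BOS query, softmax inside the block.
    scores = (bos_query.reshape(n, 1, 1, d) @ keys.transpose(-1, -2)) / math.sqrt(d)  # (n, b, 1, s)
    weights = torch.softmax(scores, dim=-1)

    # One weighted-average key/value per block: the sequence shrinks by `sparsity_factor`.
    pooled_keys = (weights @ keys).squeeze(-2)      # (n, b, d)
    pooled_values = (weights @ values).squeeze(-2)  # (n, b, d)
    return pooled_keys, pooled_values

if __name__ == "__main__":
    n, t, d, factor = 2, 16, 8, 4
    bos = torch.randn(n, 1, d)                      # query of the first (BOS) token
    k, v = torch.randn(n, t, d), torch.randn(n, t, d)
    pk, pv = bos_weighted_pooling(bos, k, v, factor)
    print(pk.shape, pv.shape)                       # torch.Size([2, 4, 8]) twice
```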