bos_token + readme
Browse files- README.md +12 -7
- modeling_lsg_albert.py +40 -13
README.md
CHANGED
@@ -69,26 +69,31 @@ model = AutoModel.from_pretrained("ccdv/lsg-albert-base-v2-4096",
|
|
69 |
|
70 |
## Sparse selection type
|
71 |
|
72 |
-
There are
|
|
|
73 |
Note that for sequences with length < 2*block_size, the type has no effect.
|
74 |
-
|
75 |
-
*
|
|
|
|
|
|
|
|
|
76 |
* Works best for a small sparsity_factor (2 to 4)
|
77 |
* Additional parameters:
|
78 |
* None
|
79 |
-
* sparsity_type="pooling"
|
80 |
* Works best for a small sparsity_factor (2 to 4)
|
81 |
* Additional parameters:
|
82 |
* None
|
83 |
-
* sparsity_type="lsh"
|
84 |
* Works best for a large sparsity_factor (4+)
|
85 |
* LSH relies on random projections, thus inference may differ slightly with different seeds
|
86 |
* Additional parameters:
|
87 |
* lsh_num_pre_rounds=1, pre-merge tokens n times before computing centroids
|
88 |
-
* sparsity_type="stride"
|
89 |
* Each head will use different tokens strided by sparsity_factor
|
90 |
* Not recommended if sparsity_factor > num_heads
|
91 |
-
* sparsity_type="block_stride"
|
92 |
* Each head will use blocks of tokens strided by sparsity_factor
|
93 |
* Not recommended if sparsity_factor > num_heads
|
94 |
|
|
|
69 |
|
70 |
## Sparse selection type
|
71 |
|
72 |
+
There are 6 different sparse selection patterns. The best type is task dependent. \
|
73 |
+
If `sparse_block_size=0` or `sparsity_type="none"`, only local attention is considered. \
|
74 |
Note that for sequences with length < 2*block_size, the type has no effect.
|
75 |
+
* `sparsity_type="bos_pooling"` (new)
|
76 |
+
* weighted average pooling using the BOS token
|
77 |
+
* Works best in general, especially with a rather large sparsity_factor (8, 16, 32)
|
78 |
+
* Additional parameters:
|
79 |
+
* None
|
80 |
+
* `sparsity_type="norm"`, select highest norm tokens
|
81 |
* Works best for a small sparsity_factor (2 to 4)
|
82 |
* Additional parameters:
|
83 |
* None
|
84 |
+
* `sparsity_type="pooling"`, use average pooling to merge tokens
|
85 |
* Works best for a small sparsity_factor (2 to 4)
|
86 |
* Additional parameters:
|
87 |
* None
|
88 |
+
* `sparsity_type="lsh"`, use the LSH algorithm to cluster similar tokens
|
89 |
* Works best for a large sparsity_factor (4+)
|
90 |
* LSH relies on random projections, thus inference may differ slightly with different seeds
|
91 |
* Additional parameters:
|
92 |
* lsh_num_pre_rounds=1, pre-merge tokens n times before computing centroids
|
93 |
+
* `sparsity_type="stride"`, use a striding mechanism per head
|
94 |
* Each head will use different tokens strided by sparsity_factor
|
95 |
* Not recommended if sparsity_factor > num_heads
|
96 |
+
* `sparsity_type="block_stride"`, use a block-wise striding mechanism per head
|
97 |
* Each head will use blocks of tokens strided by sparsity_factor
|
98 |
* Not recommended if sparsity_factor > num_heads
|
99 |
|
modeling_lsg_albert.py
CHANGED
@@ -53,16 +53,16 @@ class LSGAlbertConfig(AlbertConfig):
|
|
53 |
self.sparsity_factor = sparsity_factor
|
54 |
self.sparsity_type = sparsity_type
|
55 |
|
56 |
-
if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
|
57 |
logger.warning(
|
58 |
-
"[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
|
59 |
setting sparsity_type=None, computation will skip sparse attention")
|
60 |
self.sparsity_type = None
|
61 |
|
62 |
if self.sparsity_type in ["stride", "block_stride"]:
|
63 |
-
if self.sparsity_factor > self.
|
64 |
logger.warning(
|
65 |
-
"[WARNING CONFIG]: sparsity_factor >
|
66 |
)
|
67 |
|
68 |
if self.num_global_tokens < 1:
|
@@ -463,7 +463,7 @@ class LSGAlbertEmbeddings(AlbertEmbeddings):
|
|
463 |
return embeddings
|
464 |
|
465 |
|
466 |
-
class
|
467 |
'''
|
468 |
Compute local attention with overlapping blocs
|
469 |
Use global attention for tokens with highest norm
|
@@ -502,15 +502,16 @@ class LSGAttention(BaseSelfAttention):
|
|
502 |
"lsh": self.get_sparse_tokens_with_lsh,
|
503 |
"stride": self.get_sparse_tokens_with_stride,
|
504 |
"block_stride": self.get_sparse_tokens_with_block_stride,
|
|
|
505 |
}
|
506 |
|
507 |
self.sparsity_type = config.sparsity_type
|
508 |
-
self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
|
509 |
|
510 |
if config.sparsity_type == "lsh":
|
511 |
self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
|
512 |
|
513 |
-
def get_sparse_tokens_with_norm(self, keys, values, mask):
|
514 |
|
515 |
if self.sparsity_factor == 1:
|
516 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
@@ -538,7 +539,7 @@ class LSGAttention(BaseSelfAttention):
|
|
538 |
|
539 |
return keys, values, mask
|
540 |
|
541 |
-
def get_sparse_tokens_with_pooling(self, keys, values, mask):
|
542 |
|
543 |
if self.sparsity_factor == 1:
|
544 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
@@ -561,7 +562,7 @@ class LSGAttention(BaseSelfAttention):
|
|
561 |
mask *= torch.finfo(mask.dtype).min
|
562 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
563 |
|
564 |
-
def get_sparse_tokens_with_stride(self, keys, values, mask):
|
565 |
|
566 |
if self.sparsity_factor == 1:
|
567 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
@@ -577,7 +578,7 @@ class LSGAttention(BaseSelfAttention):
|
|
577 |
|
578 |
return keys, values, mask
|
579 |
|
580 |
-
def get_sparse_tokens_with_block_stride(self, keys, values, mask):
|
581 |
|
582 |
if self.sparsity_factor == 1:
|
583 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
@@ -597,11 +598,14 @@ class LSGAttention(BaseSelfAttention):
|
|
597 |
|
598 |
return keys, values, mask
|
599 |
|
600 |
-
def get_sparse_tokens_with_lsh(self, keys, values, mask):
|
601 |
|
602 |
if self.sparsity_factor == 1:
|
603 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
604 |
|
|
|
|
|
|
|
605 |
block_size = min(self.block_size, self.sparse_block_size)
|
606 |
keys = self.chunk(keys, block_size)
|
607 |
values = self.chunk(values, block_size)
|
@@ -649,6 +653,29 @@ class LSGAttention(BaseSelfAttention):
|
|
649 |
|
650 |
return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
|
651 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
652 |
def forward(
|
653 |
self,
|
654 |
hidden_states,
|
@@ -720,7 +747,7 @@ class LSGAttention(BaseSelfAttention):
|
|
720 |
sparse_key, sparse_value, sparse_mask = (None, None, None)
|
721 |
|
722 |
if self.sparse_block_size and self.sparsity_factor > 0:
|
723 |
-
sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
|
724 |
|
725 |
# Expand masks on heads
|
726 |
attention_mask = attention_mask.expand(-1, h, -1, -1)
|
@@ -757,7 +784,7 @@ class LSGAlbertLayer(AlbertLayer):
|
|
757 |
def __init__(self, config):
|
758 |
super().__init__(config)
|
759 |
|
760 |
-
self.attention =
|
761 |
|
762 |
|
763 |
class LSGAlbertLayerGroup(AlbertLayerGroup):
|
|
|
53 |
self.sparsity_factor = sparsity_factor
|
54 |
self.sparsity_type = sparsity_type
|
55 |
|
56 |
+
if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
|
57 |
logger.warning(
|
58 |
+
"[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], \
|
59 |
setting sparsity_type=None, computation will skip sparse attention")
|
60 |
self.sparsity_type = None
|
61 |
|
62 |
if self.sparsity_type in ["stride", "block_stride"]:
|
63 |
+
if self.sparsity_factor > self.num_attention_heads:
|
64 |
logger.warning(
|
65 |
+
"[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
|
66 |
)
|
67 |
|
68 |
if self.num_global_tokens < 1:
|
|
|
463 |
return embeddings
|
464 |
|
465 |
|
466 |
+
class LSGSelfAttention(BaseSelfAttention):
|
467 |
'''
|
468 |
Compute local attention with overlapping blocs
|
469 |
Use global attention for tokens with highest norm
|
|
|
502 |
"lsh": self.get_sparse_tokens_with_lsh,
|
503 |
"stride": self.get_sparse_tokens_with_stride,
|
504 |
"block_stride": self.get_sparse_tokens_with_block_stride,
|
505 |
+
"bos_pooling": self.get_sparse_tokens_with_bos_pooling
|
506 |
}
|
507 |
|
508 |
self.sparsity_type = config.sparsity_type
|
509 |
+
self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))
|
510 |
|
511 |
if config.sparsity_type == "lsh":
|
512 |
self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
|
513 |
|
514 |
+
def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
|
515 |
|
516 |
if self.sparsity_factor == 1:
|
517 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
|
|
539 |
|
540 |
return keys, values, mask
|
541 |
|
542 |
+
def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
|
543 |
|
544 |
if self.sparsity_factor == 1:
|
545 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
|
|
562 |
mask *= torch.finfo(mask.dtype).min
|
563 |
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
564 |
|
565 |
+
def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
|
566 |
|
567 |
if self.sparsity_factor == 1:
|
568 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
|
|
578 |
|
579 |
return keys, values, mask
|
580 |
|
581 |
+
def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
|
582 |
|
583 |
if self.sparsity_factor == 1:
|
584 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
|
|
598 |
|
599 |
return keys, values, mask
|
600 |
|
601 |
+
def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
|
602 |
|
603 |
if self.sparsity_factor == 1:
|
604 |
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
605 |
|
606 |
+
if self.sparsity_factor == self.sparse_block_size:
|
607 |
+
return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)
|
608 |
+
|
609 |
block_size = min(self.block_size, self.sparse_block_size)
|
610 |
keys = self.chunk(keys, block_size)
|
611 |
values = self.chunk(values, block_size)
|
|
|
653 |
|
654 |
return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
|
655 |
|
656 |
+
def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
|
657 |
+
|
658 |
+
if self.sparsity_factor == 1:
|
659 |
+
return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
|
660 |
+
|
661 |
+
queries = queries.unsqueeze(-3)
|
662 |
+
mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
|
663 |
+
keys = self.chunk(keys, self.sparsity_factor)
|
664 |
+
values = self.chunk(values, self.sparsity_factor)
|
665 |
+
|
666 |
+
n, h, b, t, d = keys.size()
|
667 |
+
scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
|
668 |
+
if mask is not None:
|
669 |
+
scores = scores + mask
|
670 |
+
|
671 |
+
scores = torch.softmax(scores, dim=-1)
|
672 |
+
keys = scores @ keys
|
673 |
+
values = scores @ values
|
674 |
+
mask = mask.mean(dim=-1)
|
675 |
+
mask[mask != torch.finfo(mask.dtype).min] = 0
|
676 |
+
|
677 |
+
return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
|
678 |
+
|
679 |
def forward(
|
680 |
self,
|
681 |
hidden_states,
|
|
|
747 |
sparse_key, sparse_value, sparse_mask = (None, None, None)
|
748 |
|
749 |
if self.sparse_block_size and self.sparsity_factor > 0:
|
750 |
+
sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)
|
751 |
|
752 |
# Expand masks on heads
|
753 |
attention_mask = attention_mask.expand(-1, h, -1, -1)
|
|
|
784 |
def __init__(self, config):
|
785 |
super().__init__(config)
|
786 |
|
787 |
+
self.attention = LSGSelfAttention(config)
|
788 |
|
789 |
|
790 |
class LSGAlbertLayerGroup(AlbertLayerGroup):
|