ccdv commited on
Commit
2d1da30
·
1 Parent(s): 96d2353
README.md ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - albert
4
+ - long context
5
+ language:
6
+ - en
7
+ pipeline_tag: fill-mask
8
+ ---
9
+
10
+ # LSG model
11
+ **Transformers >= 4.18.0**\
12
+ **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
13
+ **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
14
+
15
+ * [Usage](#usage)
16
+ * [Parameters](#parameters)
17
+ * [Sparse selection type](#sparse-selection-type)
18
+ * [Tasks](#tasks)
19
+
20
+ This model is adapted from [AlBERT-base-v2](https://huggingface.co/albert-base-v2) without additional pretraining. It uses the same number of parameters/layers and the same tokenizer.
21
+
22
+
23
+ This model can handle long sequences but faster and more efficiently than Longformer (LED) or BigBird (Pegasus) from the hub and relies on Local + Sparse + Global attention (LSG).
24
+
25
+ The model requires sequences whose length is a multiple of the block size. The model is "adaptive" and automatically pads the sequences if needed (adaptive=True in config). It is however recommended, thanks to the tokenizer, to truncate the inputs (truncation=True) and optionally to pad with a multiple of the block size (pad_to_multiple_of=...). \
26
+
27
+ Implemented in PyTorch.
28
+
29
+ ![attn](attn.png)
30
+
31
+ ## Usage
32
+ The model relies on a custom modeling file, you need to add trust_remote_code=True to use it.
33
+
34
+ ```python:
35
+ from transformers import AutoModel, AutoTokenizer
36
+
37
+ model = AutoModel.from_pretrained("ccdv/lsg-albert-base-v2-4096", trust_remote_code=True)
38
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-albert-base-v2-4096")
39
+ ```
40
+
41
+ ## Parameters
42
+ You can change various parameters like :
43
+ * the number of global tokens (num_global_tokens=1)
44
+ * local block size (block_size=128)
45
+ * sparse block size (sparse_block_size=128)
46
+ * sparsity factor (sparsity_factor=2)
47
+ * mask_first_token (mask first token since it is redundant with the first global token)
48
+ * see config.json file
49
+
50
+ Default parameters work well in practice. If you are short on memory, reduce block sizes, increase sparsity factor and remove dropout in the attention score matrix.
51
+
52
+ ```python:
53
+ from transformers import AutoModel
54
+
55
+ model = AutoModel.from_pretrained("ccdv/lsg-albert-base-v2-4096",
56
+ trust_remote_code=True,
57
+ num_global_tokens=16,
58
+ block_size=64,
59
+ sparse_block_size=64,
60
+ attention_probs_dropout_prob=0.0
61
+ sparsity_factor=4,
62
+ sparsity_type="none",
63
+ mask_first_token=True
64
+ )
65
+ ```
66
+
67
+ ## Sparse selection type
68
+
69
+ There are 5 different sparse selection patterns. The best type is task dependent. \
70
+ Note that for sequences with length < 2*block_size, the type has no effect.
71
+
72
+ * sparsity_type="norm", select highest norm tokens
73
+ * Works best for a small sparsity_factor (2 to 4)
74
+ * Additional parameters:
75
+ * None
76
+ * sparsity_type="pooling", use average pooling to merge tokens
77
+ * Works best for a small sparsity_factor (2 to 4)
78
+ * Additional parameters:
79
+ * None
80
+ * sparsity_type="lsh", use the LSH algorithm to cluster similar tokens
81
+ * Works best for a large sparsity_factor (4+)
82
+ * LSH relies on random projections, thus inference may differ slightly with different seeds
83
+ * Additional parameters:
84
+ * lsg_num_pre_rounds=1, pre merge tokens n times before computing centroids
85
+ * sparsity_type="stride", use a striding mecanism per head
86
+ * Each head will use different tokens strided by sparsify_factor
87
+ * Not recommended if sparsify_factor > num_heads
88
+ * sparsity_type="block_stride", use a striding mecanism per head
89
+ * Each head will use block of tokens strided by sparsify_factor
90
+ * Not recommended if sparsify_factor > num_heads
91
+
92
+ ## Tasks
93
+ Seq2Seq example for summarization:
94
+ ```python:
95
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
96
+
97
+ model = AutoModelForSeq2SeqLM.from_pretrained("ccdv/lsg-albert-base-v2-4096",
98
+ trust_remote_code=True,
99
+ pass_global_tokens_to_decoder=True, # Pass encoder global tokens to decoder
100
+ )
101
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-albert-base-v2-4096")
102
+
103
+ SENTENCE = "This is a test sequence to test the model. " * 300
104
+ token_ids = tokenizer(
105
+ SENTENCE,
106
+ return_tensors="pt",
107
+ padding="max_length", # Optional but recommended
108
+ truncation=True # Optional but recommended
109
+ )
110
+ output = model(**token_ids)
111
+ ```
112
+
113
+
114
+ Classification example:
115
+ ```python:
116
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
117
+
118
+ model = AutoModelForSequenceClassification.from_pretrained("ccdv/lsg-albert-base-v2-4096",
119
+ trust_remote_code=True,
120
+ pass_global_tokens_to_decoder=True, # Pass encoder global tokens to decoder
121
+ )
122
+ tokenizer = AutoTokenizer.from_pretrained("ccdv/lsg-albert-base-v2-4096")
123
+
124
+ SENTENCE = "This is a test sequence to test the model. " * 300
125
+ token_ids = tokenizer(
126
+ SENTENCE,
127
+ return_tensors="pt",
128
+ #pad_to_multiple_of=... # Optional
129
+ truncation=True
130
+ )
131
+ output = model(**token_ids)
132
+
133
+ > SequenceClassifierOutput(loss=None, logits=tensor([[-0.3051, -0.1762]], grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
134
+ ```
135
+
136
+ **AlBERT**
137
+ ```
138
+ @article{DBLP:journals/corr/abs-1909-11942,
139
+ author = {Zhenzhong Lan and
140
+ Mingda Chen and
141
+ Sebastian Goodman and
142
+ Kevin Gimpel and
143
+ Piyush Sharma and
144
+ Radu Soricut},
145
+ title = {{ALBERT:} {A} Lite {BERT} for Self-supervised Learning of Language
146
+ Representations},
147
+ journal = {CoRR},
148
+ volume = {abs/1909.11942},
149
+ year = {2019},
150
+ url = {http://arxiv.org/abs/1909.11942},
151
+ archivePrefix = {arXiv},
152
+ eprint = {1909.11942},
153
+ timestamp = {Fri, 27 Sep 2019 13:04:21 +0200},
154
+ biburl = {https://dblp.org/rec/journals/corr/abs-1909-11942.bib},
155
+ bibsource = {dblp computer science bibliography, https://dblp.org}
156
+ }
157
+ ```
attn.png ADDED
config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "roberta-test",
3
+ "adaptive": true,
4
+ "architectures": [
5
+ "LSGAlbertForMaskedLM"
6
+ ],
7
+ "attention_probs_dropout_prob": 0,
8
+ "auto_map": {
9
+ "AutoConfig": "modeling_lsg_albert.LSGAlbertConfig",
10
+ "AutoModel": "modeling_lsg_albert.LSGAlbertModel",
11
+ "AutoModelForMaskedLM": "modeling_lsg_albert.LSGAlbertForMaskedLM",
12
+ "AutoModelForMultipleChoice": "modeling_lsg_albert.LSGAlbertForMultipleChoice",
13
+ "AutoModelForPreTraining": "modeling_lsg_albert.LSGAlbertForPreTraining",
14
+ "AutoModelForQuestionAnswering": "modeling_lsg_albert.LSGAlbertForQuestionAnswering",
15
+ "AutoModelForSequenceClassification": "modeling_lsg_albert.LSGAlbertForSequenceClassification",
16
+ "AutoModelForTokenClassification": "modeling_lsg_albert.LSGAlbertForTokenClassification"
17
+ },
18
+ "base_model_prefix": "lsg",
19
+ "block_size": 128,
20
+ "bos_token_id": 2,
21
+ "classifier_dropout_prob": 0.1,
22
+ "down_scale_factor": 1,
23
+ "embedding_size": 128,
24
+ "eos_token_id": 3,
25
+ "gap_size": 0,
26
+ "hidden_act": "gelu_new",
27
+ "hidden_dropout_prob": 0,
28
+ "hidden_size": 768,
29
+ "initializer_range": 0.02,
30
+ "inner_group_num": 1,
31
+ "intermediate_size": 3072,
32
+ "layer_norm_eps": 1e-12,
33
+ "lsh_num_pre_rounds": 1,
34
+ "mask_first_token": true,
35
+ "max_position_embeddings": 4096,
36
+ "model_type": "albert",
37
+ "net_structure_type": 0,
38
+ "num_attention_heads": 12,
39
+ "num_global_tokens": 1,
40
+ "num_hidden_groups": 1,
41
+ "num_hidden_layers": 12,
42
+ "num_memory_blocks": 0,
43
+ "pad_token_id": 0,
44
+ "pool_with_global": true,
45
+ "position_embedding_type": "absolute",
46
+ "sparse_block_size": 128,
47
+ "sparsity_factor": 2,
48
+ "sparsity_type": "norm",
49
+ "torch_dtype": "float32",
50
+ "transformers_version": "4.20.1",
51
+ "type_vocab_size": 2,
52
+ "vocab_size": 30000
53
+ }
modeling_lsg_albert.py ADDED
@@ -0,0 +1,1014 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from logging import warn
2
+ from transformers.models.albert.modeling_albert import *
3
+ import torch
4
+ import torch.nn as nn
5
+ from transformers.models.albert.configuration_albert import AlbertConfig
6
+ import sys
7
+
8
+ AUTO_MAP = {
9
+ "AutoModel": "modeling_lsg_albert.LSGAlbertModel",
10
+ "AutoModelForMaskedLM": "modeling_lsg_albert.LSGAlbertForMaskedLM",
11
+ "AutoModelForPreTraining": "modeling_lsg_albert.LSGAlbertForPreTraining",
12
+ "AutoModelForMultipleChoice": "modeling_lsg_albert.LSGAlbertForMultipleChoice",
13
+ "AutoModelForQuestionAnswering": "modeling_lsg_albert.LSGAlbertForQuestionAnswering",
14
+ "AutoModelForSequenceClassification": "modeling_lsg_albert.LSGAlbertForSequenceClassification",
15
+ "AutoModelForTokenClassification": "modeling_lsg_albert.LSGAlbertForTokenClassification"
16
+ }
17
+
18
+ class LSGAlbertConfig(AlbertConfig):
19
+ """
20
+ This class overrides :class:`~transformers.LSGAlbertConfig`. Please check the superclass for the appropriate
21
+ documentation alongside usage examples.
22
+ """
23
+
24
+ base_model_prefix = "lsg"
25
+ model_type = "albert"
26
+
27
+ def __init__(
28
+ self,
29
+ adaptive=True,
30
+ base_model_prefix="lsg",
31
+ block_size=128,
32
+ lsh_num_pre_rounds=1,
33
+ mask_first_token=False,
34
+ num_global_tokens=1,
35
+ pool_with_global=True,
36
+ sparse_block_size=128,
37
+ sparsity_factor=2,
38
+ sparsity_type="norm",
39
+ **kwargs
40
+ ):
41
+ """Constructs LSGAlbertConfig."""
42
+ super().__init__(**kwargs)
43
+
44
+ self.adaptive = adaptive
45
+ self.auto_map = AUTO_MAP
46
+ self.base_model_prefix = base_model_prefix
47
+ self.block_size = block_size
48
+ self.lsh_num_pre_rounds = lsh_num_pre_rounds
49
+ self.mask_first_token = mask_first_token
50
+ self.num_global_tokens = num_global_tokens
51
+ self.pool_with_global = pool_with_global
52
+ self.sparse_block_size = sparse_block_size
53
+ self.sparsity_factor = sparsity_factor
54
+ self.sparsity_type = sparsity_type
55
+
56
+ if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
57
+ logger.warning(
58
+ "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
59
+ self.sparsity_type = None
60
+
61
+ if self.sparsity_type in ["stride", "block_stride"]:
62
+ if self.sparsity_factor > self.encoder_attention_heads:
63
+ logger.warning(
64
+ "[WARNING CONFIG]: sparsity_factor > encoder_attention_heads is not recommended for stride/block_stride sparsity"
65
+ )
66
+
67
+ if self.num_global_tokens < 1:
68
+ logger.warning(
69
+ "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
70
+ )
71
+ self.num_global_tokens = 1
72
+ elif self.num_global_tokens > 512:
73
+ logger.warning(
74
+ "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
75
+ )
76
+ self.num_global_tokens = 512
77
+
78
+ if self.sparsity_factor > 0:
79
+ assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
80
+ assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
81
+
82
+
83
+ class BaseSelfAttention(nn.Module):
84
+
85
+ def init_modules(self, config):
86
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
87
+ config, "embedding_size"
88
+ ):
89
+ raise ValueError(
90
+ "The hidden size (%d) is not a multiple of the number of attention "
91
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
92
+ )
93
+
94
+ self.num_attention_heads = config.num_attention_heads
95
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
96
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
97
+
98
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
99
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
100
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
101
+
102
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
103
+
104
+ def transpose_for_scores(self, x):
105
+ new_x_shape = x.size()[:-1] + (
106
+ self.num_attention_heads,
107
+ self.attention_head_size,
108
+ )
109
+ x = x.view(*new_x_shape)
110
+ return x.permute(0, 2, 1, 3)
111
+
112
+ def reshape_output(self, context_layer):
113
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
114
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
115
+ return context_layer.view(*new_context_layer_shape)
116
+
117
+ def project_QKV(self, hidden_states):
118
+
119
+ query_layer = self.transpose_for_scores(self.query(hidden_states))
120
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
121
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
122
+ return query_layer, key_layer, value_layer
123
+
124
+
125
+ class BaseAttentionProduct(nn.Module):
126
+
127
+ def __init__(self, config):
128
+ """
129
+ Compute attention: softmax(Q @ K.T) @ V
130
+ """
131
+ super().__init__()
132
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
133
+
134
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
135
+
136
+ d = query_layer.shape[-1]
137
+
138
+ # Take the dot product between "query" and "key" to get the raw attention scores.
139
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
140
+
141
+ del query_layer
142
+ del key_layer
143
+
144
+ if attention_mask is not None:
145
+ # Apply the attention mask is (precomputed for all layers in AlbertModel forward() function)
146
+ attention_scores = attention_scores + attention_mask
147
+ del attention_mask
148
+
149
+ # Normalize the attention scores to probabilities.
150
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
151
+
152
+ # This is actually dropping out entire tokens to attend to, which might
153
+ # seem a bit unusual, but is taken from the original Transformer paper.
154
+ context_layer = self.dropout(attention_probs) @ value_layer
155
+
156
+ return context_layer
157
+
158
+
159
+ class CausalAttentionProduct(nn.Module):
160
+
161
+ def __init__(self, config):
162
+ """
163
+ Compute attention: softmax(Q @ K.T) @ V
164
+ """
165
+ super().__init__()
166
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
167
+ self.block_size = config.block_size
168
+
169
+ def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
170
+
171
+ d = query_layer.shape[-1]
172
+
173
+ # Take the dot product between "query" and "key" to get the raw attention scores.
174
+ attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)
175
+
176
+ del query_layer
177
+ del key_layer
178
+
179
+ if attention_mask is not None:
180
+ # Apply the attention mask is (precomputed for all layers in AlbertModel forward() function)
181
+ attention_scores = attention_scores + attention_mask
182
+
183
+ # Add causal mask
184
+ causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
185
+ causal_mask = torch.tril(
186
+ torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
187
+ diagonal=-1
188
+ )
189
+ causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
190
+ attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
191
+
192
+ del attention_mask
193
+
194
+ # Normalize the attention scores to probabilities.
195
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
196
+
197
+ # This is actually dropping out entire tokens to attend to, which might
198
+ # seem a bit unusual, but is taken from the original Transformer paper.
199
+ context_layer = self.dropout(attention_probs) @ value_layer
200
+
201
+ return context_layer
202
+
203
+
204
+ class LSGAttentionProduct(nn.Module):
205
+
206
+ def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
207
+ """
208
+ Compute block or overlapping blocks attention products
209
+ """
210
+ super().__init__()
211
+
212
+ self.block_size = block_size
213
+ self.sparse_block_size = sparse_block_size
214
+ self.sparsity_factor = sparsity_factor
215
+ self.is_causal = is_causal
216
+
217
+ if self.block_size is None:
218
+ self.block_size = config.block_size
219
+
220
+ if self.sparse_block_size is None:
221
+ self.sparse_block_size = config.sparse_block_size
222
+
223
+ # Shape of blocks
224
+ self.local_shapes = (self.block_size*3, self.block_size)
225
+ if self.sparse_block_size and self.sparsity_factor > 0:
226
+ self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)
227
+
228
+ if is_causal:
229
+ self.attention = CausalAttentionProduct(config)
230
+ else:
231
+ self.attention = BaseAttentionProduct(config)
232
+
233
+ def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
234
+
235
+ # Build local tokens
236
+ local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
237
+ del hidden_states
238
+
239
+ # Build sparse tokens
240
+ if sparse_hidden_states is not None:
241
+ sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)
242
+
243
+ return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)
244
+
245
+ def forward(
246
+ self,
247
+ query_layer,
248
+ key_layer,
249
+ value_layer,
250
+ attention_mask=None,
251
+ sparse_key=None,
252
+ sparse_value=None,
253
+ sparse_mask=None,
254
+ global_key=None,
255
+ global_value=None,
256
+ global_mask=None
257
+ ):
258
+
259
+ # Input batch, heads, length, hidden_size
260
+ n, h, t, d = query_layer.size()
261
+ n_blocks = t // self.block_size
262
+ assert t % self.block_size == 0
263
+
264
+ key_layer = self.build_lsg_inputs(
265
+ key_layer,
266
+ sparse_key,
267
+ global_key
268
+ )
269
+ del sparse_key
270
+ del global_key
271
+
272
+ value_layer = self.build_lsg_inputs(
273
+ value_layer,
274
+ sparse_value,
275
+ global_value
276
+ )
277
+ del sparse_value
278
+ del global_value
279
+
280
+ attention_mask = self.build_lsg_inputs(
281
+ attention_mask,
282
+ sparse_mask,
283
+ global_mask.transpose(-1, -2),
284
+ is_attn_mask=True
285
+ ).transpose(-1, -2)
286
+ del sparse_mask
287
+ del global_mask
288
+
289
+ # expect (..., t, d) shape
290
+ # Compute attention
291
+ context_layer = self.attention(
292
+ query_layer=self.chunk(query_layer, n_blocks),
293
+ key_layer=key_layer,
294
+ value_layer=value_layer,
295
+ attention_mask=attention_mask
296
+ )
297
+
298
+ return context_layer.reshape(n, h, -1, d)
299
+
300
+ def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
301
+
302
+ size, step = self.local_shapes
303
+ s = (size - step) // 2
304
+
305
+ # Pad before block reshaping
306
+ if is_attn_mask:
307
+ pad_value = torch.finfo(hidden_states.dtype).min
308
+ hidden_states = hidden_states.transpose(-1, -2)
309
+ else:
310
+ pad_value = 0
311
+
312
+ hidden_states = torch.nn.functional.pad(
313
+ hidden_states.transpose(-1, -2),
314
+ pad=(s, s),
315
+ value=pad_value
316
+ ).transpose(-1, -2)
317
+
318
+ # Make blocks
319
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
320
+
321
+ # Skip third block if causal
322
+ if self.is_causal:
323
+ return hidden_states[..., :size*2//3, :]
324
+
325
+ return hidden_states
326
+
327
+ def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
328
+
329
+ size, step = self.sparse_shapes
330
+
331
+ # In case of odd case
332
+ odd_offset = (step % 2)
333
+
334
+ # n, h, t, d*2 + 1
335
+ size = size*2
336
+ s = (size - step) // 2 + odd_offset
337
+
338
+ # Pad before block reshaping
339
+ if is_attn_mask:
340
+ pad_value = torch.finfo(hidden_states.dtype).min
341
+ hidden_states = hidden_states.transpose(-1, -2)
342
+ else:
343
+ pad_value = 0
344
+
345
+ hidden_states = torch.nn.functional.pad(
346
+ hidden_states.transpose(-1, -2),
347
+ pad=(s, s),
348
+ value=pad_value
349
+ ).transpose(-1, -2)
350
+
351
+ # Make blocks
352
+ hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)
353
+
354
+ # Fix case where block_size == sparsify_factor
355
+ if odd_offset:
356
+ hidden_states = hidden_states[..., :-1, :, :]
357
+
358
+ # Indexes for selection
359
+ u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
360
+ s = self.sparse_block_size
361
+
362
+ # Skip right block if causal
363
+ if self.is_causal:
364
+ return hidden_states[..., u-s:u, :]
365
+
366
+ u_ = u + odd_offset
367
+ return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)
368
+
369
+ def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
370
+
371
+ n, h, b, t, d = x_local.size()
372
+ x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
373
+ if x_sparse is not None:
374
+ return torch.cat([x_global, x_sparse, x_local], dim=dim)
375
+ return torch.cat([x_global, x_local], dim=dim)
376
+
377
+ def chunk(self, x, n_blocks):
378
+
379
+ t, d = x.size()[-2:]
380
+ return x.reshape(*x.size()[:-2], n_blocks, -1, d)
381
+
382
+
383
+ class LSGAlbertEmbeddings(AlbertEmbeddings):
384
+ """
385
+ Construct the embeddings from word, position and token_type embeddings.
386
+ """
387
+
388
+ def __init__(self, config):
389
+ super().__init__(config)
390
+
391
+ self.num_global_tokens = config.num_global_tokens
392
+
393
+ # Hardcoded but partially trained
394
+ self.global_embeddings = nn.Embedding(512, embedding_dim=config.embedding_size, )
395
+
396
+ self.block_size = config.block_size
397
+
398
+ def forward(
399
+ self,
400
+ input_ids=None,
401
+ token_type_ids=None,
402
+ position_ids=None,
403
+ inputs_embeds=None,
404
+ past_key_values_length=0,
405
+ ) -> torch.Tensor:
406
+ if input_ids is not None:
407
+ input_shape = input_ids.size()
408
+ else:
409
+ input_shape = inputs_embeds.size()[:-1]
410
+
411
+ seq_length = input_shape[1]
412
+
413
+ if position_ids is None:
414
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
415
+
416
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
417
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
418
+ # issue #5664
419
+ if token_type_ids is None:
420
+ if hasattr(self, "token_type_ids"):
421
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
422
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
423
+ token_type_ids = buffered_token_type_ids_expanded
424
+ else:
425
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
426
+
427
+ if inputs_embeds is None:
428
+ inputs_embeds = self.word_embeddings(input_ids)
429
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
430
+
431
+ embeddings = inputs_embeds + token_type_embeddings
432
+ if self.position_embedding_type == "absolute":
433
+ position_embeddings = self.position_embeddings(position_ids)
434
+ embeddings += position_embeddings
435
+
436
+ n, t, d = embeddings.size()
437
+
438
+ # Add global_tokens
439
+ indexes = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
440
+ global_embeddings = self.global_embeddings(indexes)
441
+ embeddings = torch.cat([global_embeddings.expand(n, -1, d), embeddings], dim=-2)
442
+
443
+
444
+ embeddings = self.LayerNorm(embeddings)
445
+ embeddings = self.dropout(embeddings)
446
+ return embeddings
447
+
448
+
449
+ class LSGAttention(BaseSelfAttention):
450
+ '''
451
+ Compute local attention with overlapping blocs
452
+ Use global attention for tokens with highest norm
453
+ '''
454
+ def __init__(self, config):
455
+ super().__init__()
456
+
457
+ self.init_modules(config)
458
+
459
+ self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
460
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
461
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
462
+
463
+ self.block_size = config.block_size
464
+ self.sparse_block_size = config.sparse_block_size
465
+ self.num_global_tokens = config.num_global_tokens
466
+ self.sparsity_factor = config.sparsity_factor
467
+ self.is_causal = config.is_decoder
468
+ self.is_decoder = config.is_decoder
469
+
470
+ self.attention = LSGAttentionProduct(
471
+ config,
472
+ block_size=config.block_size,
473
+ sparse_block_size=config.sparse_block_size,
474
+ sparsity_factor=self.sparsity_factor,
475
+ is_causal=self.is_causal
476
+ )
477
+
478
+ if self.is_causal:
479
+ self.causal_attention = CausalAttentionProduct(config)
480
+ self.full_attention = BaseAttentionProduct(config)
481
+
482
+ sparse_functions = {
483
+ "norm": self.get_sparse_tokens_with_norm,
484
+ "pooling": self.get_sparse_tokens_with_pooling,
485
+ "lsh": self.get_sparse_tokens_with_lsh,
486
+ "stride": self.get_sparse_tokens_with_stride,
487
+ "block_stride": self.get_sparse_tokens_with_block_stride,
488
+ }
489
+
490
+ self.sparsity_type = config.sparsity_type
491
+ self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda x, y, z: (None, None, None))
492
+
493
+ if config.sparsity_type == "lsh":
494
+ self.lsh_num_pre_rounds = config.lsh_num_pre_rounds
495
+
496
+ def get_sparse_tokens_with_norm(self, keys, values, mask):
497
+
498
+ if self.sparsity_factor == 1:
499
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
500
+
501
+ with torch.no_grad():
502
+
503
+ block_size = min(self.block_size, self.sparse_block_size)
504
+ key_norm = keys.detach().norm(dim=-1, keepdim=True)
505
+ key_norm = key_norm * ~mask.transpose(-1, -2).bool()
506
+ key_norm = self.chunk(key_norm, block_size)
507
+
508
+ n, h, b, t, d = key_norm.size()
509
+
510
+ idx = key_norm.argsort(dim=-2)
511
+ del key_norm
512
+ idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)
513
+
514
+ split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
515
+ sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)
516
+
517
+ d = keys.size()[-1]
518
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
519
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
520
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
521
+
522
+ return keys, values, mask
523
+
524
+ def get_sparse_tokens_with_pooling(self, keys, values, mask):
525
+
526
+ if self.sparsity_factor == 1:
527
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
528
+
529
+ keys = self.chunk(keys, self.sparsity_factor)
530
+ values = self.chunk(values, self.sparsity_factor)
531
+
532
+ n, h, b, t, d = keys.size()
533
+ mask = mask.reshape(n, 1, b, 1, t)
534
+ mask = ~mask.transpose(-1, -2).bool()
535
+
536
+ keys = keys * mask
537
+ values = values * mask
538
+
539
+ mask = mask.sum(dim=-2)
540
+ keys = keys.sum(dim=-2) / (mask + 1e-6)
541
+ values = values.sum(dim=-2) / (mask + 1e-6)
542
+
543
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
544
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)
545
+
546
+ def get_sparse_tokens_with_stride(self, keys, values, mask):
547
+
548
+ if self.sparsity_factor == 1:
549
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
550
+
551
+ n, h, t, d = keys.size()
552
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
553
+ sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
554
+ sparse_idx = sparse_idx.expand(n, h, -1, 1)
555
+
556
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
557
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
558
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
559
+
560
+ return keys, values, mask
561
+
562
+ def get_sparse_tokens_with_block_stride(self, keys, values, mask):
563
+
564
+ if self.sparsity_factor == 1:
565
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
566
+
567
+ n, h, t, d = keys.size()
568
+
569
+ t, b = self.block_size, t // self.block_size
570
+ sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
571
+ sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
572
+ sparse_idx = (sparse_idx % t)
573
+ sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
574
+ sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)
575
+
576
+ keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
577
+ values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
578
+ mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)
579
+
580
+ return keys, values, mask
581
+
582
+ def get_sparse_tokens_with_lsh(self, keys, values, mask):
583
+
584
+ if self.sparsity_factor == 1:
585
+ return keys, values, mask.expand(-1, keys.size()[1], -1, -1)
586
+
587
+ block_size = min(self.block_size, self.sparse_block_size)
588
+ keys = self.chunk(keys, block_size)
589
+ values = self.chunk(values, block_size)
590
+
591
+ n, h, b, t, d = keys.size()
592
+ mask = mask.reshape(n, 1, b, 1, t)
593
+ mask = ~mask.transpose(-1, -2).bool()
594
+
595
+ keys = keys * mask
596
+ values = values * mask
597
+ mask = mask.expand(-1, h, -1, -1, -1).float()
598
+
599
+ extra_factor = 1
600
+
601
+ for _ in range(self.lsh_num_pre_rounds):
602
+ keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)
603
+
604
+ keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
605
+ keys /= mask + 1e-8
606
+ values /= mask + 1e-8
607
+
608
+ mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
609
+
610
+ return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)
611
+
612
+ def lsh_round(self, keys, values, mask, output_size):
613
+
614
+ with torch.no_grad():
615
+
616
+ n_hashes = output_size // 2
617
+ n, h, b, t, d = keys.size()
618
+ binary_mask = mask.clamp(0, 1)
619
+
620
+ indexes = (torch.nn.functional.normalize(keys, dim=-1) * binary_mask) @ torch.randn(1, h, 1, d, n_hashes, device=keys.device)
621
+ indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)
622
+
623
+ n, h, b, t, d = keys.size()
624
+
625
+ x_ = torch.zeros(n, h, b, output_size, d, device=keys.device)
626
+ mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device)
627
+ keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
628
+ values = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
629
+ mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)
630
+
631
+ return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]
632
+
633
+ def forward(
634
+ self,
635
+ hidden_states,
636
+ attention_mask=None,
637
+ head_mask=None,
638
+ encoder_hidden_states=None,
639
+ encoder_attention_mask=None,
640
+ past_key_value=None,
641
+ output_attentions=False,
642
+ ):
643
+
644
+ query_layer, key_layer, value_layer = self.project_QKV(hidden_states)
645
+ outputs = self.not_causal_forward(
646
+ query_layer,
647
+ key_layer,
648
+ value_layer,
649
+ attention_mask=attention_mask,
650
+ output_attentions=output_attentions
651
+ )
652
+
653
+ context = outputs[0]
654
+ context = self.dense(context)
655
+ context = self.output_dropout(context)
656
+ context = self.LayerNorm(context + hidden_states)
657
+
658
+ outputs = (context, ) + outputs[1:]
659
+
660
+ #if head_mask is not None:
661
+ # outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
662
+ return outputs
663
+
664
+ def not_causal_forward(
665
+ self,
666
+ query_layer,
667
+ key_layer,
668
+ value_layer,
669
+ attention_mask=None,
670
+ output_attentions=False,
671
+ ):
672
+
673
+ n, h, t, d = query_layer.size()
674
+
675
+ # Cat global mask
676
+ attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)
677
+
678
+ # Use normal attention if local attention covers every tokens
679
+ if t <= 2 * self.block_size + self.num_global_tokens:
680
+ context_layer = self.full_attention(
681
+ query_layer=query_layer,
682
+ key_layer=key_layer,
683
+ value_layer=value_layer,
684
+ attention_mask=attention_mask
685
+ )
686
+ return (self.reshape_output(context_layer), )
687
+
688
+ # Split input into global tokens and other tokens
689
+ split = (self.num_global_tokens, t - self.num_global_tokens)
690
+ global_query, query_layer = query_layer.split(split, dim=-2)
691
+
692
+ # Get global_attention
693
+ bos = self.full_attention(
694
+ query_layer=global_query,
695
+ key_layer=key_layer,
696
+ value_layer=value_layer,
697
+ attention_mask=attention_mask
698
+ )
699
+
700
+ # Split K Q M on global and non global
701
+ global_key, key_layer = key_layer.split(split, dim=-2)
702
+ global_value, value_layer = value_layer.split(split, dim=-2)
703
+ global_mask, attention_mask = attention_mask.split(split, dim=-1)
704
+
705
+ n, h, t, d = key_layer.size()
706
+
707
+ # Get sparse idx
708
+ sparse_key, sparse_value, sparse_mask = (None, None, None)
709
+
710
+ if self.sparse_block_size and self.sparsity_factor > 0:
711
+ sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(key_layer, value_layer, attention_mask)
712
+
713
+ # Expand masks on heads
714
+ attention_mask = attention_mask.expand(-1, h, -1, -1)
715
+ global_mask = global_mask.expand(-1, h, -1, -1)
716
+
717
+ # Compute dot product attention
718
+ context_layer = self.attention(
719
+ query_layer,
720
+ key_layer,
721
+ value_layer,
722
+ attention_mask,
723
+ sparse_key=sparse_key,
724
+ sparse_value=sparse_value,
725
+ sparse_mask=sparse_mask,
726
+ global_key=global_key,
727
+ global_value=global_value,
728
+ global_mask=global_mask
729
+ )
730
+
731
+ # Merge global and local-sparse tokens
732
+ context_layer = torch.cat([bos, context_layer], dim=-2)
733
+ context_layer = self.reshape_output(context_layer)
734
+
735
+ return (context_layer,)
736
+
737
+ def chunk(self, x, chunk_size):
738
+
739
+ n, h, t, d = x.size()
740
+ return x.reshape(n, h, -1, chunk_size, d)
741
+
742
+
743
+ class LSGAlbertLayer(AlbertLayer):
744
+
745
+ def __init__(self, config):
746
+ super().__init__(config)
747
+
748
+ self.attention = LSGAttention(config)
749
+
750
+
751
+ class LSGAlbertLayerGroup(AlbertLayerGroup):
752
+
753
+ def __init__(self, config):
754
+ nn.Module.__init__(self)
755
+
756
+ self.albert_layers = nn.ModuleList([LSGAlbertLayer(config) for _ in range(config.inner_group_num)])
757
+
758
+
759
+ class LSGAlbertTransformer(AlbertTransformer):
760
+
761
+ def __init__(self, config):
762
+ nn.Module.__init__(self)
763
+
764
+ self.config = config
765
+ self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
766
+ self.albert_layer_groups = nn.ModuleList([LSGAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
767
+
768
+
769
+ class LSGAlbertPreTrainedModel(PreTrainedModel):
770
+ """
771
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
772
+ models.
773
+ """
774
+
775
+ config_class = LSGAlbertConfig
776
+ load_tf_weights = load_tf_weights_in_albert
777
+ base_model_prefix = "albert"
778
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
779
+
780
+ def _init_weights(self, module):
781
+ """Initialize the weights."""
782
+ if isinstance(module, nn.Linear):
783
+ # Slightly different from the TF version which uses truncated_normal for initialization
784
+ # cf https://github.com/pytorch/pytorch/pull/5617
785
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
786
+ if module.bias is not None:
787
+ module.bias.data.zero_()
788
+ elif isinstance(module, nn.Embedding):
789
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
790
+ if module.padding_idx is not None:
791
+ module.weight.data[module.padding_idx].zero_()
792
+ elif isinstance(module, nn.LayerNorm):
793
+ module.bias.data.zero_()
794
+ module.weight.data.fill_(1.0)
795
+
796
+
797
+ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
798
+
799
+ config_class = LSGAlbertConfig
800
+ base_model_prefix = "albert"
801
+
802
+ def __init__(self, config, add_pooling_layer=True):
803
+ AlbertPreTrainedModel.__init__(self, config)
804
+
805
+ assert hasattr(config, "num_global_tokens")
806
+ self.num_global_tokens = config.num_global_tokens
807
+ self.pad_idx = config.pad_token_id
808
+
809
+ assert hasattr(config, "block_size") and hasattr(config, "adaptive")
810
+ self.block_size = config.block_size
811
+ self.adaptive = config.adaptive
812
+ self.mask_first_token = config.mask_first_token
813
+ self.pool_with_global = config.pool_with_global
814
+
815
+ self.config = config
816
+ self.embeddings = LSGAlbertEmbeddings(config)
817
+ self.encoder = LSGAlbertTransformer(config)
818
+ if add_pooling_layer:
819
+ self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
820
+ self.pooler_activation = nn.Tanh()
821
+ else:
822
+ self.pooler = None
823
+ self.pooler_activation = None
824
+
825
+ # Initialize weights and apply final processing
826
+ self.post_init()
827
+
828
+ def forward(
829
+ self,
830
+ input_ids=None,
831
+ attention_mask=None,
832
+ token_type_ids=None,
833
+ position_ids=None,
834
+ head_mask=None,
835
+ inputs_embeds=None,
836
+ output_attentions=None,
837
+ output_hidden_states=None,
838
+ return_dict=None,
839
+ ):
840
+
841
+ inputs_ = input_ids if input_ids is not None else inputs_embeds
842
+ n, t = inputs_.size()[:2]
843
+
844
+ if attention_mask is None:
845
+ attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
846
+ if self.mask_first_token:
847
+ attention_mask[:,0] = 0
848
+
849
+ b = self.block_size * 2
850
+ pad = t % self.block_size
851
+
852
+ # Check if t is multiple of block_size and pad
853
+ if self.adaptive and t > b and pad > 0:
854
+ pad_length = self.block_size - pad
855
+ if input_ids is not None:
856
+ input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
857
+ else:
858
+ inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
859
+
860
+ attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
861
+
862
+ if token_type_ids is not None:
863
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
864
+ if position_ids is not None:
865
+ position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
866
+
867
+ n, t_ = attention_mask.size()
868
+
869
+ encoder_outputs = super().forward(
870
+ input_ids=input_ids,
871
+ attention_mask=attention_mask,
872
+ token_type_ids=token_type_ids,
873
+ position_ids=position_ids,
874
+ head_mask=head_mask,
875
+ inputs_embeds=inputs_embeds,
876
+ output_attentions=output_attentions,
877
+ output_hidden_states=output_hidden_states,
878
+ return_dict=return_dict
879
+ )
880
+
881
+ context = encoder_outputs[0]
882
+ if self.pool_with_global:
883
+ context[:, self.num_global_tokens] = context[:, 0]
884
+
885
+ diff = t - t_
886
+ n, _, d = context.size()
887
+ context = context[..., self.num_global_tokens:, :]
888
+
889
+ # Adapt sequence to initial shape
890
+ if diff < 0:
891
+ context = context[:, :t]
892
+
893
+ encoder_outputs.last_hidden_state = context
894
+ sequence_output = encoder_outputs[0]
895
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
896
+
897
+ if not return_dict:
898
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
899
+
900
+ return BaseModelOutputWithPooling(
901
+ last_hidden_state=sequence_output,
902
+ pooler_output=pooled_output,
903
+ hidden_states=encoder_outputs.hidden_states,
904
+ attentions=encoder_outputs.attentions,
905
+ )
906
+
907
+
908
+ class LSGAlbertForPreTraining(LSGAlbertPreTrainedModel, AlbertForPreTraining):
909
+
910
+ def __init__(self, config):
911
+
912
+ LSGAlbertPreTrainedModel.__init__(self, config)
913
+
914
+ self.albert = LSGAlbertModel(config)
915
+ self.predictions = AlbertMLMHead(config)
916
+ self.sop_classifier = AlbertSOPHead(config)
917
+
918
+ # Initialize weights and apply final processing
919
+ self.post_init()
920
+
921
+
922
+ class LSGAlbertForMaskedLM(LSGAlbertPreTrainedModel, AlbertForMaskedLM):
923
+
924
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
925
+
926
+ def __init__(self, config):
927
+ LSGAlbertPreTrainedModel.__init__(self, config)
928
+
929
+ self.albert = LSGAlbertModel(config, add_pooling_layer=False)
930
+ self.predictions = AlbertMLMHead(config)
931
+
932
+ # Initialize weights and apply final processing
933
+ self.post_init()
934
+
935
+
936
+ class LSGAlbertForSequenceClassification(LSGAlbertPreTrainedModel, AlbertForSequenceClassification):
937
+
938
+ def __init__(self, config):
939
+
940
+ LSGAlbertPreTrainedModel.__init__(self, config)
941
+ self.num_labels = config.num_labels
942
+ self.config = config
943
+
944
+ self.albert = LSGAlbertModel(config)
945
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
946
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
947
+
948
+ # Initialize weights and apply final processing
949
+ self.post_init()
950
+
951
+
952
+ class LSGAlbertForTokenClassification(LSGAlbertPreTrainedModel, AlbertForTokenClassification):
953
+
954
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
955
+
956
+ def __init__(self, config):
957
+
958
+ LSGAlbertPreTrainedModel.__init__(self, config)
959
+ self.num_labels = config.num_labels
960
+
961
+ self.albert = LSGAlbertModel(config, add_pooling_layer=False)
962
+ classifier_dropout_prob = (
963
+ config.classifier_dropout_prob
964
+ if config.classifier_dropout_prob is not None
965
+ else config.hidden_dropout_prob
966
+ )
967
+ self.dropout = nn.Dropout(classifier_dropout_prob)
968
+ self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
969
+
970
+ # Initialize weights and apply final processing
971
+ self.post_init()
972
+
973
+
974
+ class LSGAlbertForQuestionAnswering(LSGAlbertPreTrainedModel, AlbertForQuestionAnswering):
975
+
976
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
977
+
978
+ def __init__(self, config):
979
+
980
+ LSGAlbertPreTrainedModel.__init__(self, config)
981
+ self.num_labels = config.num_labels
982
+
983
+ self.albert = LSGAlbertModel(config, add_pooling_layer=False)
984
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
985
+
986
+ # Initialize weights and apply final processing
987
+ self.post_init()
988
+
989
+
990
+ class LSGAlbertForMultipleChoice(LSGAlbertPreTrainedModel, AlbertForMultipleChoice):
991
+
992
+ def __init__(self, config):
993
+
994
+ LSGAlbertPreTrainedModel.__init__(self, config)
995
+
996
+ self.albert = LSGAlbertModel(config)
997
+ self.dropout = nn.Dropout(config.classifier_dropout_prob)
998
+ self.classifier = nn.Linear(config.hidden_size, 1)
999
+
1000
+ # Initialize weights and apply final processing
1001
+ self.post_init()
1002
+
1003
+
1004
+ def str_to_class(classname):
1005
+ return getattr(sys.modules[__name__], classname)
1006
+
1007
+ # Register model in Auto API
1008
+ try:
1009
+ LSGAlbertConfig.register_for_auto_class()
1010
+ for key, value in AUTO_MAP.items():
1011
+ str_to_class(value.split(".")[-1]).register_for_auto_class(key)
1012
+ except:
1013
+ warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
1014
+ warn("Update to transformers >= 4.17.0 to fix.")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f0633412e0ab789a25bb0e397da111223a8b1bf154815084547d9c6ab50e1e6
3
+ size 47288075
special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[CLS]",
3
+ "cls_token": "[CLS]",
4
+ "eos_token": "[SEP]",
5
+ "mask_token": {
6
+ "content": "[MASK]",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "[SEP]",
14
+ "unk_token": "<unk>"
15
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336
3
+ size 760289
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": "[CLS]",
3
+ "cls_token": "[CLS]",
4
+ "do_lower_case": true,
5
+ "eos_token": "[SEP]",
6
+ "keep_accents": false,
7
+ "mask_token": {
8
+ "__type": "AddedToken",
9
+ "content": "[MASK]",
10
+ "lstrip": true,
11
+ "normalized": false,
12
+ "rstrip": false,
13
+ "single_word": false
14
+ },
15
+ "model_max_length": 4096,
16
+ "name_or_path": "albert-base-v2",
17
+ "pad_token": "<pad>",
18
+ "remove_space": true,
19
+ "sep_token": "[SEP]",
20
+ "special_tokens_map_file": null,
21
+ "tokenizer_class": "AlbertTokenizer",
22
+ "unk_token": "<unk>"
23
+ }