zpn commited on
Commit
b49751a
·
1 Parent(s): 5b90141

Upload BertForMaskedLM

Browse files
Files changed (4) hide show
  1. config.json +42 -0
  2. configuring_nt_bert.py +162 -0
  3. modeling_nt_bert.py +999 -0
  4. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "single_bp_2k_step19999",
3
+ "architectures": [
4
+ "BertForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "attn_norm_layer_type": "layer_norm",
8
+ "attn_num_groups": 1,
9
+ "auto_map": {
10
+ "AutoConfig": "configuring_nt_bert.BertConfig",
11
+ "AutoModelForMaskedLM": "modeling_nt_bert.BertForMaskedLM"
12
+ },
13
+ "classifier_dropout": "None",
14
+ "embedding_norm_layer_type": "layer_norm",
15
+ "embedding_num_groups": 1,
16
+ "embedding_size": 1280,
17
+ "hidden_act": "gelu",
18
+ "hidden_dropout_prob": 0.1,
19
+ "hidden_size": 1280,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 5120,
22
+ "layer_norm_eps": 1e-12,
23
+ "max_position_embeddings": 2000,
24
+ "model_type": "bert",
25
+ "mup": true,
26
+ "num_attention_heads": 16,
27
+ "num_hidden_layers": 24,
28
+ "output_mult": 1,
29
+ "pad_token_id": 3,
30
+ "position_embedding_type": "alibi",
31
+ "prenorm": false,
32
+ "query_zero_init": false,
33
+ "readout_zero_init": false,
34
+ "summary_activation": "gelu",
35
+ "summary_last_dropout": 0.1,
36
+ "summary_type": "first",
37
+ "summary_use_proj": true,
38
+ "torch_dtype": "float32",
39
+ "transformers_version": "4.25.1",
40
+ "type_vocab_size": 2,
41
+ "vocab_size": 10
42
+ }
configuring_nt_bert.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class BertConfig(PretrainedConfig):
5
+ r"""
6
+ This is the configuration class to store the configuration of a :class:`~transformers.ElectraModel` or a
7
+ :class:`~transformers.TFElectraModel`. It is used to instantiate a ELECTRA model according to the specified
8
+ arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar
9
+ configuration to that of the ELECTRA `google/electra-small-discriminator
10
+ <https://huggingface.co/google/electra-small-discriminator>`__ architecture.
11
+
12
+ Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
13
+ outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
14
+
15
+
16
+ Args:
17
+ vocab_size (:obj:`int`, `optional`, defaults to 30522):
18
+ Vocabulary size of the ELECTRA model. Defines the number of different tokens that can be represented by the
19
+ :obj:`inputs_ids` passed when calling :class:`~transformers.ElectraModel` or
20
+ :class:`~transformers.TFElectraModel`.
21
+ embedding_size (:obj:`int`, `optional`, defaults to 128):
22
+ Dimensionality of the encoder layers and the pooler layer.
23
+ hidden_size (:obj:`int`, `optional`, defaults to 256):
24
+ Dimensionality of the encoder layers and the pooler layer.
25
+ num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
26
+ Number of hidden layers in the Transformer encoder.
27
+ num_attention_heads (:obj:`int`, `optional`, defaults to 4):
28
+ Number of attention heads for each attention layer in the Transformer encoder.
29
+ intermediate_size (:obj:`int`, `optional`, defaults to 1024):
30
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
31
+ hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
32
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
33
+ :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
34
+ hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
35
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
36
+ attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
37
+ The dropout ratio for the attention probabilities.
38
+ max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
39
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
40
+ just in case (e.g., 512 or 1024 or 2048).
41
+ type_vocab_size (:obj:`int`, `optional`, defaults to 2):
42
+ The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.ElectraModel` or
43
+ :class:`~transformers.TFElectraModel`.
44
+ initializer_range (:obj:`float`, `optional`, defaults to 0.02):
45
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
46
+ layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
47
+ The epsilon used by the layer normalization layers.
48
+ summary_type (:obj:`str`, `optional`, defaults to :obj:`"first"`):
49
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
50
+
51
+ Has to be one of the following options:
52
+
53
+ - :obj:`"last"`: Take the last token hidden state (like XLNet).
54
+ - :obj:`"first"`: Take the first token hidden state (like BERT).
55
+ - :obj:`"mean"`: Take the mean of all tokens hidden states.
56
+ - :obj:`"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
57
+ - :obj:`"attn"`: Not implemented now, use multi-head attention.
58
+ summary_use_proj (:obj:`bool`, `optional`, defaults to :obj:`True`):
59
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
60
+
61
+ Whether or not to add a projection after the vector extraction.
62
+ summary_activation (:obj:`str`, `optional`):
63
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
64
+
65
+ Pass :obj:`"gelu"` for a gelu activation to the output, any other value will result in no activation.
66
+ summary_last_dropout (:obj:`float`, `optional`, defaults to 0.0):
67
+ Argument used when doing sequence summary. Used in the sequence classification and multiple choice models.
68
+
69
+ The dropout ratio to be used after the projection and activation.
70
+ position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
71
+ Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
72
+ :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
73
+ :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
74
+ <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
75
+ `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
76
+ <https://arxiv.org/abs/2009.13658>`__.
77
+ classifier_dropout (:obj:`float`, `optional`):
78
+ The dropout ratio for the classification head.
79
+
80
+ Examples::
81
+
82
+ >>> from transformers import ElectraModel, ElectraConfig
83
+
84
+ >>> # Initializing a ELECTRA electra-base-uncased style configuration
85
+ >>> configuration = ElectraConfig()
86
+
87
+ >>> # Initializing a model from the electra-base-uncased style configuration
88
+ >>> model = ElectraModel(configuration)
89
+
90
+ >>> # Accessing the model configuration
91
+ >>> configuration = model.config
92
+ """
93
+ model_type = "bert"
94
+
95
+ def __init__(
96
+ self,
97
+ vocab_size=30522,
98
+ embedding_size=128,
99
+ hidden_size=256,
100
+ num_hidden_layers=12,
101
+ num_attention_heads=4,
102
+ intermediate_size=1024,
103
+ hidden_act="gelu",
104
+ hidden_dropout_prob=0.1,
105
+ attention_probs_dropout_prob=0.1,
106
+ max_position_embeddings=512,
107
+ type_vocab_size=2,
108
+ initializer_range=0.02,
109
+ layer_norm_eps=1e-12,
110
+ summary_type="first",
111
+ summary_use_proj=True,
112
+ summary_activation="gelu",
113
+ summary_last_dropout=0.1,
114
+ pad_token_id=0,
115
+ position_embedding_type="absolute",
116
+ classifier_dropout=None,
117
+ prenorm=False,
118
+ mup=False,
119
+ embedding_norm_layer_type="layer_norm",
120
+ embedding_num_groups=1,
121
+ attn_norm_layer_type="layer_norm",
122
+ attn_num_groups=1,
123
+ output_mult=1,
124
+ readout_zero_init=False,
125
+ query_zero_init=False,
126
+ **kwargs,
127
+ ):
128
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
129
+
130
+ self.vocab_size = vocab_size
131
+ self.embedding_size = embedding_size
132
+ self.hidden_size = hidden_size
133
+ self.num_hidden_layers = num_hidden_layers
134
+ self.num_attention_heads = num_attention_heads
135
+ self.intermediate_size = intermediate_size
136
+ self.hidden_act = hidden_act
137
+ self.hidden_dropout_prob = hidden_dropout_prob
138
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
139
+ self.max_position_embeddings = max_position_embeddings
140
+ self.type_vocab_size = type_vocab_size
141
+ self.initializer_range = initializer_range
142
+ self.layer_norm_eps = layer_norm_eps
143
+ # passing in 1e-x in config turns to string
144
+ if isinstance(self.layer_norm_eps, str):
145
+ self.layer_norm_eps = float(self.layer_norm_eps)
146
+
147
+ self.summary_type = summary_type
148
+ self.summary_use_proj = summary_use_proj
149
+ self.summary_activation = summary_activation
150
+ self.summary_last_dropout = summary_last_dropout
151
+ self.position_embedding_type = position_embedding_type
152
+ self.classifier_dropout = classifier_dropout
153
+ # transformers without tears suggests using prenorm
154
+ self.prenorm = prenorm
155
+ self.mup = mup
156
+ self.embedding_norm_layer_type = embedding_norm_layer_type
157
+ self.embedding_num_groups = embedding_num_groups
158
+ self.attn_norm_layer_type = attn_norm_layer_type
159
+ self.attn_num_groups = attn_num_groups
160
+ self.output_mult = output_mult
161
+ self.readout_zero_init = readout_zero_init
162
+ self.query_zero_init = query_zero_init
modeling_nt_bert.py ADDED
@@ -0,0 +1,999 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from huggingface_hub import hf_hub_download
8
+ from mup import MuReadout, set_base_shapes
9
+ from mup.init import normal_
10
+ from nt_transformer.models.nt_bert.configuring_nt_bert import BertConfig
11
+ from rotary_embedding_torch import RotaryEmbedding
12
+ from transformers.modeling_outputs import (
13
+ BaseModelOutputWithPastAndCrossAttentions,
14
+ MaskedLMOutput,
15
+ )
16
+ from transformers.modeling_utils import (
17
+ PreTrainedModel,
18
+ apply_chunking_to_forward,
19
+ find_pruneable_heads_and_indices,
20
+ get_activation,
21
+ prune_linear_layer,
22
+ )
23
+
24
+
25
+ class BertPreTrainedModel(PreTrainedModel):
26
+ """
27
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
28
+ models.
29
+ """
30
+
31
+ config_class = BertConfig
32
+ base_model_prefix = "bert"
33
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
34
+ _keys_to_ignore_on_load_unexpected = [
35
+ r"bert\.embeddings_project\.weight",
36
+ r"bert\.embeddings_project\.bias",
37
+ ]
38
+
39
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
40
+ def _init_weights(self, module, readout_zero_init=False, query_zero_init=False):
41
+ """Initialize the weights"""
42
+ if isinstance(module, nn.Linear):
43
+ # Slightly different from the TF version which uses truncated_normal for initialization
44
+ # cf https://github.com/pytorch/pytorch/pull/5617
45
+ ### muP: swap constant std normal init with normal_ from `mup.init`.
46
+ ### Because `_init_weights` is called in `__init__`, before `infshape` is set,
47
+ ### we need to manually call `self.apply(self._init_weights)` after calling
48
+ ### `set_base_shape(model, base)`
49
+ if isinstance(module, MuReadout) and readout_zero_init:
50
+ module.weight.data.zero_()
51
+ else:
52
+ if hasattr(module.weight, "infshape"):
53
+ normal_(module.weight, mean=0.0, std=self.config.initializer_range)
54
+ else:
55
+ module.weight.data.normal_(
56
+ mean=0.0, std=self.config.initializer_range
57
+ )
58
+ ### End muP
59
+ if module.bias is not None:
60
+ module.bias.data.zero_()
61
+ elif isinstance(module, nn.Embedding):
62
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
63
+ if module.padding_idx is not None:
64
+ module.weight.data[module.padding_idx].zero_()
65
+ elif isinstance(module, nn.LayerNorm):
66
+ module.bias.data.zero_()
67
+ module.weight.data.fill_(1.0)
68
+ ### muP
69
+ if isinstance(module, BertSelfAttention):
70
+ if query_zero_init:
71
+ module.query.weight.data[:] = 0
72
+
73
+ @classmethod
74
+ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
75
+ model = super().from_pretrained(
76
+ pretrained_model_name_or_path, *model_args, **kwargs
77
+ )
78
+
79
+ # since we used MuP, need to reset values since they're not saved with the model
80
+ if os.path.exists("base_shapes.bsh") is False:
81
+ hf_hub_download(
82
+ "zpn/human_bp_bert", "base_shapes.bsh"
83
+ )
84
+
85
+ set_base_shapes(model, "base_shapes.bsh", rescale_params=False)
86
+
87
+ return model
88
+
89
+
90
+ class BertEmbeddings(nn.Module):
91
+ """Construct the embeddings from word, position and token_type embeddings."""
92
+
93
+ def __init__(self, config):
94
+ super().__init__()
95
+ self.word_embeddings = nn.Embedding(
96
+ config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
97
+ )
98
+ self.position_embeddings = nn.Embedding(
99
+ config.max_position_embeddings, config.embedding_size
100
+ )
101
+ self.token_type_embeddings = nn.Embedding(
102
+ config.type_vocab_size, config.embedding_size
103
+ )
104
+
105
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
106
+ # any TensorFlow checkpoint file
107
+
108
+ if config.embedding_norm_layer_type == "layer_norm":
109
+ self.norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
110
+ elif config.embedding_norm_layer_type == "group_norm":
111
+ self.norm = nn.GroupNorm(
112
+ num_groups=config.embedding_num_groups,
113
+ num_channels=config.embedding_size,
114
+ )
115
+ else:
116
+ raise ValueError(
117
+ f"Unknown attn_norm_layer_type {config.attn_norm_layer_type}"
118
+ )
119
+
120
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
121
+
122
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
123
+ self.register_buffer(
124
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
125
+ )
126
+ self.position_embedding_type = getattr(
127
+ config, "position_embedding_type", "absolute"
128
+ )
129
+
130
+ self.register_buffer(
131
+ "token_type_ids",
132
+ torch.zeros(
133
+ self.position_ids.size(),
134
+ dtype=torch.long,
135
+ device=self.position_ids.device,
136
+ ),
137
+ persistent=False,
138
+ )
139
+
140
+ def forward(
141
+ self,
142
+ input_ids=None,
143
+ token_type_ids=None,
144
+ position_ids=None,
145
+ inputs_embeds=None,
146
+ past_key_values_length=0,
147
+ ):
148
+ if input_ids is not None:
149
+ input_shape = input_ids.size()
150
+ else:
151
+ input_shape = inputs_embeds.size()[:-1]
152
+
153
+ seq_length = input_shape[1]
154
+
155
+ if position_ids is None:
156
+ position_ids = self.position_ids[
157
+ :, past_key_values_length : seq_length + past_key_values_length
158
+ ]
159
+
160
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
161
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
162
+ # issue #5664
163
+ if token_type_ids is None:
164
+ if hasattr(self, "token_type_ids"):
165
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
166
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
167
+ input_shape[0], seq_length
168
+ )
169
+ token_type_ids = buffered_token_type_ids_expanded
170
+ else:
171
+ token_type_ids = torch.zeros(
172
+ input_shape, dtype=torch.long, device=self.position_ids.device
173
+ )
174
+
175
+ if inputs_embeds is None:
176
+ inputs_embeds = self.word_embeddings(input_ids)
177
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
178
+
179
+ embeddings = inputs_embeds + token_type_embeddings
180
+ if self.position_embedding_type == "absolute":
181
+ position_embeddings = self.position_embeddings(position_ids)
182
+ embeddings += position_embeddings
183
+
184
+ if isinstance(self.norm, nn.GroupNorm):
185
+ # group norm only works over channel dim
186
+ reshaped = embeddings.permute(0, 2, 1)
187
+ embeddings = self.norm(reshaped)
188
+ embeddings = embeddings.permute(0, 2, 1)
189
+ else:
190
+ embeddings = self.norm(embeddings)
191
+
192
+ embeddings = self.dropout(embeddings)
193
+ return embeddings
194
+
195
+
196
+ class BertIntermediate(nn.Module):
197
+ def __init__(self, config):
198
+ super().__init__()
199
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
200
+ if isinstance(config.hidden_act, str):
201
+ self.intermediate_act_fn = get_activation(config.hidden_act)
202
+ else:
203
+ self.intermediate_act_fn = config.hidden_act
204
+
205
+ def forward(self, hidden_states):
206
+ hidden_states = self.dense(hidden_states)
207
+ hidden_states = self.intermediate_act_fn(hidden_states)
208
+ return hidden_states
209
+
210
+
211
+ class BertLayer(nn.Module):
212
+ def __init__(self, config):
213
+ super().__init__()
214
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
215
+ self.seq_len_dim = 1
216
+ self.attention = BertAttention(config)
217
+ self.is_decoder = config.is_decoder
218
+ self.add_cross_attention = config.add_cross_attention
219
+ if self.add_cross_attention:
220
+ assert (
221
+ self.is_decoder
222
+ ), f"{self} should be used as a decoder model if cross attention is added"
223
+ self.crossattention = BertAttention(config)
224
+ self.intermediate = BertIntermediate(config)
225
+ self.output = BertOutput(config)
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states,
230
+ attention_mask=None,
231
+ head_mask=None,
232
+ encoder_hidden_states=None,
233
+ encoder_attention_mask=None,
234
+ past_key_value=None,
235
+ output_attentions=False,
236
+ ):
237
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
238
+ self_attn_past_key_value = (
239
+ past_key_value[:2] if past_key_value is not None else None
240
+ )
241
+ self_attention_outputs = self.attention(
242
+ hidden_states,
243
+ attention_mask,
244
+ head_mask,
245
+ output_attentions=output_attentions,
246
+ past_key_value=self_attn_past_key_value,
247
+ )
248
+ attention_output = self_attention_outputs[0]
249
+
250
+ # if decoder, the last output is tuple of self-attn cache
251
+ if self.is_decoder:
252
+ outputs = self_attention_outputs[1:-1]
253
+ present_key_value = self_attention_outputs[-1]
254
+ else:
255
+ outputs = self_attention_outputs[
256
+ 1:
257
+ ] # add self attentions if we output attention weights
258
+
259
+ cross_attn_present_key_value = None
260
+ if self.is_decoder and encoder_hidden_states is not None:
261
+ assert hasattr(
262
+ self, "crossattention"
263
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
264
+
265
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
266
+ cross_attn_past_key_value = (
267
+ past_key_value[-2:] if past_key_value is not None else None
268
+ )
269
+ cross_attention_outputs = self.crossattention(
270
+ attention_output,
271
+ attention_mask,
272
+ head_mask,
273
+ encoder_hidden_states,
274
+ encoder_attention_mask,
275
+ cross_attn_past_key_value,
276
+ output_attentions,
277
+ )
278
+ attention_output = cross_attention_outputs[0]
279
+ outputs = (
280
+ outputs + cross_attention_outputs[1:-1]
281
+ ) # add cross attentions if we output attention weights
282
+
283
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
284
+ cross_attn_present_key_value = cross_attention_outputs[-1]
285
+ present_key_value = present_key_value + cross_attn_present_key_value
286
+
287
+ layer_output = apply_chunking_to_forward(
288
+ self.feed_forward_chunk,
289
+ self.chunk_size_feed_forward,
290
+ self.seq_len_dim,
291
+ attention_output,
292
+ )
293
+ outputs = (layer_output,) + outputs
294
+
295
+ # if decoder, return the attn key/values as the last output
296
+ if self.is_decoder:
297
+ outputs = outputs + (present_key_value,)
298
+
299
+ return outputs
300
+
301
+ def feed_forward_chunk(self, attention_output):
302
+ intermediate_output = self.intermediate(attention_output)
303
+ layer_output = self.output(intermediate_output, attention_output)
304
+ return layer_output
305
+
306
+
307
+ class BertEncoder(nn.Module):
308
+ def __init__(self, config):
309
+ super().__init__()
310
+ self.config = config
311
+ self.layer = nn.ModuleList(
312
+ [BertLayer(config) for _ in range(config.num_hidden_layers)]
313
+ )
314
+
315
+ def forward(
316
+ self,
317
+ hidden_states,
318
+ attention_mask=None,
319
+ head_mask=None,
320
+ encoder_hidden_states=None,
321
+ encoder_attention_mask=None,
322
+ past_key_values=None,
323
+ use_cache=None,
324
+ output_attentions=False,
325
+ output_hidden_states=False,
326
+ return_dict=True,
327
+ ):
328
+ all_hidden_states = () if output_hidden_states else None
329
+ all_self_attentions = () if output_attentions else None
330
+ all_cross_attentions = (
331
+ () if output_attentions and self.config.add_cross_attention else None
332
+ )
333
+
334
+ next_decoder_cache = () if use_cache else None
335
+ for i, layer_module in enumerate(self.layer):
336
+ if output_hidden_states:
337
+ all_hidden_states = all_hidden_states + (hidden_states,)
338
+
339
+ layer_head_mask = head_mask[i] if head_mask is not None else None
340
+ past_key_value = past_key_values[i] if past_key_values is not None else None
341
+
342
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
343
+ if use_cache:
344
+ use_cache = False
345
+
346
+ def create_custom_forward(module):
347
+ def custom_forward(*inputs):
348
+ return module(*inputs, past_key_value, output_attentions)
349
+
350
+ return custom_forward
351
+
352
+ layer_outputs = torch.utils.checkpoint.checkpoint(
353
+ create_custom_forward(layer_module),
354
+ hidden_states,
355
+ attention_mask,
356
+ layer_head_mask,
357
+ encoder_hidden_states,
358
+ encoder_attention_mask,
359
+ )
360
+ else:
361
+ layer_outputs = layer_module(
362
+ hidden_states,
363
+ attention_mask,
364
+ layer_head_mask,
365
+ encoder_hidden_states,
366
+ encoder_attention_mask,
367
+ past_key_value,
368
+ output_attentions,
369
+ )
370
+
371
+ hidden_states = layer_outputs[0]
372
+ if use_cache:
373
+ next_decoder_cache += (layer_outputs[-1],)
374
+ if output_attentions:
375
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
376
+ if self.config.add_cross_attention:
377
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
378
+
379
+ if output_hidden_states:
380
+ all_hidden_states = all_hidden_states + (hidden_states,)
381
+
382
+ if not return_dict:
383
+ return tuple(
384
+ v
385
+ for v in [
386
+ hidden_states,
387
+ next_decoder_cache,
388
+ all_hidden_states,
389
+ all_self_attentions,
390
+ all_cross_attentions,
391
+ ]
392
+ if v is not None
393
+ )
394
+ return BaseModelOutputWithPastAndCrossAttentions(
395
+ last_hidden_state=hidden_states,
396
+ past_key_values=next_decoder_cache,
397
+ hidden_states=all_hidden_states,
398
+ attentions=all_self_attentions,
399
+ cross_attentions=all_cross_attentions,
400
+ )
401
+
402
+
403
+ class BertOutput(nn.Module):
404
+ def __init__(self, config):
405
+ super().__init__()
406
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
407
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
408
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
409
+
410
+ def forward(self, hidden_states, input_tensor):
411
+ hidden_states = self.dense(hidden_states)
412
+ hidden_states = self.dropout(hidden_states)
413
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
414
+ return hidden_states
415
+
416
+
417
+ # shamelessly stolen from: https://github.com/lucidrains/x-transformers/blob/fb1671342d3b27a748336873c225fbd4dd66b7a1/x_transformers/x_transformers.py#L267
418
+ class AlibiPositionalBias(nn.Module):
419
+ def __init__(self, heads, **kwargs):
420
+ super().__init__()
421
+ self.heads = heads
422
+ slopes = torch.Tensor(self._get_slopes(heads))
423
+ slopes = rearrange(slopes, "h -> h 1 1")
424
+ self.register_buffer("slopes", slopes, persistent=False)
425
+ self.register_buffer("bias", None, persistent=False)
426
+
427
+ def get_bias(self, i, j, device):
428
+ i_arange = torch.arange(j - i, j, device=device)
429
+ j_arange = torch.arange(j, device=device)
430
+ bias = -torch.abs(
431
+ rearrange(j_arange, "j -> 1 1 j") - rearrange(i_arange, "i -> 1 i 1")
432
+ )
433
+ return bias
434
+
435
+ @staticmethod
436
+ def _get_slopes(heads):
437
+ def get_slopes_power_of_2(n):
438
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
439
+ ratio = start
440
+ return [start * ratio**i for i in range(n)]
441
+
442
+ if math.log2(heads).is_integer():
443
+ return get_slopes_power_of_2(heads)
444
+
445
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
446
+ return (
447
+ get_slopes_power_of_2(closest_power_of_2)
448
+ + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
449
+ : heads - closest_power_of_2
450
+ ]
451
+ )
452
+
453
+ def forward(self, qk_dots):
454
+ h, i, j, device = *qk_dots.shape[-3:], qk_dots.device
455
+
456
+ if self.bias is not None and self.bias.shape[-1] >= j:
457
+ return qk_dots + self.bias[..., :i, :j]
458
+
459
+ bias = self.get_bias(i, j, device)
460
+ bias = bias * self.slopes
461
+
462
+ num_heads_unalibied = h - bias.shape[0]
463
+ bias = F.pad(bias, (0, 0, 0, 0, 0, num_heads_unalibied))
464
+ self.register_buffer("bias", bias, persistent=False)
465
+
466
+ return qk_dots + self.bias
467
+
468
+
469
+ class BertModel(BertPreTrainedModel):
470
+ def __init__(self, config):
471
+ super().__init__(config)
472
+ self.embeddings = BertEmbeddings(config)
473
+
474
+ if config.embedding_size != config.hidden_size:
475
+ self.embeddings_project = nn.Linear(
476
+ config.embedding_size, config.hidden_size
477
+ )
478
+
479
+ self.encoder = BertEncoder(config)
480
+ self.config = config
481
+ self.init_weights()
482
+
483
+ def get_input_embeddings(self):
484
+ return self.embeddings.word_embeddings
485
+
486
+ def set_input_embeddings(self, value):
487
+ self.embeddings.word_embeddings = value
488
+
489
+ def _prune_heads(self, heads_to_prune):
490
+ """
491
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
492
+ class PreTrainedModel
493
+ """
494
+ for layer, heads in heads_to_prune.items():
495
+ self.encoder.layer[layer].attention.prune_heads(heads)
496
+
497
+ def forward(
498
+ self,
499
+ input_ids=None,
500
+ attention_mask=None,
501
+ token_type_ids=None,
502
+ position_ids=None,
503
+ head_mask=None,
504
+ inputs_embeds=None,
505
+ output_attentions=None,
506
+ output_hidden_states=None,
507
+ return_dict=None,
508
+ ):
509
+ output_attentions = (
510
+ output_attentions
511
+ if output_attentions is not None
512
+ else self.config.output_attentions
513
+ )
514
+ output_hidden_states = (
515
+ output_hidden_states
516
+ if output_hidden_states is not None
517
+ else self.config.output_hidden_states
518
+ )
519
+ return_dict = (
520
+ return_dict if return_dict is not None else self.config.use_return_dict
521
+ )
522
+
523
+ if input_ids is not None and inputs_embeds is not None:
524
+ raise ValueError(
525
+ "You cannot specify both input_ids and inputs_embeds at the same time"
526
+ )
527
+ elif input_ids is not None:
528
+ input_shape = input_ids.size()
529
+ batch_size, seq_length = input_shape
530
+ elif inputs_embeds is not None:
531
+ input_shape = inputs_embeds.size()[:-1]
532
+ else:
533
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
534
+
535
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
536
+
537
+ if attention_mask is None:
538
+ attention_mask = torch.ones(input_shape, device=device)
539
+ if token_type_ids is None:
540
+ if hasattr(self.embeddings, "token_type_ids"):
541
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
542
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(
543
+ batch_size, seq_length
544
+ )
545
+ token_type_ids = buffered_token_type_ids_expanded
546
+ else:
547
+ token_type_ids = torch.zeros(
548
+ input_shape, dtype=torch.long, device=device
549
+ )
550
+
551
+ extended_attention_mask = self.get_extended_attention_mask(
552
+ attention_mask, input_shape, device
553
+ )
554
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
555
+
556
+ hidden_states = self.embeddings(
557
+ input_ids=input_ids,
558
+ position_ids=position_ids,
559
+ token_type_ids=token_type_ids,
560
+ inputs_embeds=inputs_embeds,
561
+ )
562
+
563
+ if hasattr(self, "embeddings_project"):
564
+ hidden_states = self.embeddings_project(hidden_states)
565
+
566
+ hidden_states = self.encoder(
567
+ hidden_states,
568
+ attention_mask=extended_attention_mask,
569
+ head_mask=head_mask,
570
+ output_attentions=output_attentions,
571
+ output_hidden_states=output_hidden_states,
572
+ return_dict=return_dict,
573
+ )
574
+
575
+ return hidden_states
576
+
577
+
578
+ class BertSelfOutput(nn.Module):
579
+ def __init__(self, config):
580
+ super().__init__()
581
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
582
+ if config.prenorm:
583
+ self.norm = nn.Identity()
584
+ else:
585
+ if config.attn_norm_layer_type == "layer_norm":
586
+ self.norm = nn.LayerNorm(config.hidden_size)
587
+ elif config.attn_norm_layer_type == "group_norm":
588
+ self.norm = nn.GroupNorm(
589
+ num_groups=config.attn_num_groups, num_channels=config.hidden_size
590
+ )
591
+ else:
592
+ raise ValueError(
593
+ f"Unknown attn_norm_layer_type {config.attn_norm_layer_type}"
594
+ )
595
+
596
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
597
+
598
+ def forward(self, hidden_states, input_tensor):
599
+ hidden_states = self.dense(hidden_states)
600
+ hidden_states = self.dropout(hidden_states)
601
+ if isinstance(self.norm, nn.GroupNorm):
602
+ reshaped = hidden_states + input_tensor
603
+ # group norm only works over channel dim
604
+ reshaped = reshaped.permute(0, 2, 1)
605
+ hidden_states = self.norm(reshaped)
606
+ hidden_states = hidden_states.permute(0, 2, 1)
607
+ else:
608
+ hidden_states = self.norm(hidden_states + input_tensor)
609
+ return hidden_states
610
+
611
+
612
+ class BertSelfAttention(nn.Module):
613
+ def __init__(self, config):
614
+ super().__init__()
615
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
616
+ config, "embedding_size"
617
+ ):
618
+ raise ValueError(
619
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
620
+ f"heads ({config.num_attention_heads})"
621
+ )
622
+
623
+ self.num_attention_heads = config.num_attention_heads
624
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
625
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
626
+
627
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
628
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
629
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
630
+
631
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
632
+ self.position_embedding_type = getattr(
633
+ config, "position_embedding_type", "absolute"
634
+ )
635
+ if (
636
+ self.position_embedding_type == "relative_key"
637
+ or self.position_embedding_type == "relative_key_query"
638
+ ):
639
+ self.max_position_embeddings = config.max_position_embeddings
640
+ self.distance_embedding = nn.Embedding(
641
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
642
+ )
643
+ elif self.position_embedding_type == "rotary":
644
+ self.rotary = RotaryEmbedding(dim=self.attention_head_size)
645
+ elif self.position_embedding_type == "alibi":
646
+ self.alibi = AlibiPositionalBias(self.num_attention_heads)
647
+
648
+ self.is_decoder = config.is_decoder
649
+
650
+ if config.mup:
651
+ self.attention_scaling_factor = self.attention_head_size
652
+ else:
653
+ self.attention_scaling_factor = math.sqrt(self.attention_head_size)
654
+
655
+ def transpose_for_scores(self, x):
656
+ new_x_shape = x.size()[:-1] + (
657
+ self.num_attention_heads,
658
+ self.attention_head_size,
659
+ )
660
+ x = x.view(*new_x_shape)
661
+ return x.permute(0, 2, 1, 3)
662
+
663
+ def forward(
664
+ self,
665
+ hidden_states,
666
+ attention_mask=None,
667
+ head_mask=None,
668
+ encoder_hidden_states=None,
669
+ encoder_attention_mask=None,
670
+ past_key_value=None,
671
+ output_attentions=False,
672
+ ):
673
+ mixed_query_layer = self.query(hidden_states)
674
+
675
+ # If this is instantiated as a cross-attention module, the keys
676
+ # and values come from an encoder; the attention mask needs to be
677
+ # such that the encoder's padding tokens are not attended to.
678
+ is_cross_attention = encoder_hidden_states is not None
679
+
680
+ if is_cross_attention and past_key_value is not None:
681
+ # reuse k,v, cross_attentions
682
+ key_layer = past_key_value[0]
683
+ value_layer = past_key_value[1]
684
+ attention_mask = encoder_attention_mask
685
+ elif is_cross_attention:
686
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
687
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
688
+ attention_mask = encoder_attention_mask
689
+ elif past_key_value is not None:
690
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
691
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
692
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
693
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
694
+ else:
695
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
696
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
697
+
698
+ query_layer = self.transpose_for_scores(mixed_query_layer)
699
+
700
+ if self.is_decoder:
701
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
702
+ # Further calls to cross_attention layer can then reuse all cross-attention
703
+ # key/value_states (first "if" case)
704
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
705
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
706
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
707
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
708
+ past_key_value = (key_layer, value_layer)
709
+
710
+ if self.position_embedding_type == "rotary":
711
+ query_layer = self.rotary.rotate_queries_or_keys(query_layer)
712
+ key_layer = self.rotary.rotate_queries_or_keys(key_layer)
713
+
714
+ # Take the dot product between "query" and "key" to get the raw attention scores.
715
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
716
+
717
+ if (
718
+ self.position_embedding_type == "relative_key"
719
+ or self.position_embedding_type == "relative_key_query"
720
+ ):
721
+ seq_length = hidden_states.size()[1]
722
+ position_ids_l = torch.arange(
723
+ seq_length, dtype=torch.long, device=hidden_states.device
724
+ ).view(-1, 1)
725
+ position_ids_r = torch.arange(
726
+ seq_length, dtype=torch.long, device=hidden_states.device
727
+ ).view(1, -1)
728
+ distance = position_ids_l - position_ids_r
729
+ positional_embedding = self.distance_embedding(
730
+ distance + self.max_position_embeddings - 1
731
+ )
732
+ positional_embedding = positional_embedding.to(
733
+ dtype=query_layer.dtype
734
+ ) # fp16 compatibility
735
+
736
+ if self.position_embedding_type == "relative_key":
737
+ relative_position_scores = torch.einsum(
738
+ "bhld,lrd->bhlr", query_layer, positional_embedding
739
+ )
740
+ attention_scores = attention_scores + relative_position_scores
741
+ elif self.position_embedding_type == "relative_key_query":
742
+ relative_position_scores_query = torch.einsum(
743
+ "bhld,lrd->bhlr", query_layer, positional_embedding
744
+ )
745
+ relative_position_scores_key = torch.einsum(
746
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
747
+ )
748
+ attention_scores = (
749
+ attention_scores
750
+ + relative_position_scores_query
751
+ + relative_position_scores_key
752
+ )
753
+
754
+ # attention scaling -> for mup need to rescale to 1/d
755
+ attention_scores = attention_scores / self.attention_scaling_factor
756
+
757
+ if self.position_embedding_type == "alibi":
758
+ attention_scores = self.alibi(attention_scores)
759
+
760
+ if attention_mask is not None:
761
+ # Apply the attention mask is (precomputed for all layers in ElectraModel forward() function)
762
+ attention_scores = attention_scores + attention_mask
763
+
764
+ # Normalize the attention scores to probabilities.
765
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
766
+
767
+ # This is actually dropping out entire tokens to attend to, which might
768
+ # seem a bit unusual, but is taken from the original Transformer paper.
769
+ attention_probs = self.dropout(attention_probs)
770
+
771
+ # Mask heads if we want to
772
+ if head_mask is not None:
773
+ attention_probs = attention_probs * head_mask
774
+
775
+ context_layer = torch.matmul(attention_probs, value_layer)
776
+
777
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
778
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
779
+ context_layer = context_layer.view(*new_context_layer_shape)
780
+
781
+ outputs = (
782
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
783
+ )
784
+
785
+ if self.is_decoder:
786
+ outputs = outputs + (past_key_value,)
787
+ return outputs
788
+
789
+
790
+ class BertAttention(nn.Module):
791
+ def __init__(self, config):
792
+ super().__init__()
793
+ self.self = BertSelfAttention(config)
794
+ self.output = BertSelfOutput(config)
795
+ if config.prenorm:
796
+ if config.attn_norm_layer_type == "layer_norm":
797
+ self.prenorm = nn.LayerNorm(
798
+ config.hidden_size, eps=config.layer_norm_eps
799
+ )
800
+ elif config.attn_norm_layer_type == "group_norm":
801
+ self.prenorm = nn.GroupNorm(
802
+ num_groups=config.attn_num_groups,
803
+ num_channels=config.hidden_size,
804
+ eps=config.layer_norm_eps,
805
+ )
806
+ else:
807
+ raise ValueError(
808
+ f"Unknown attn_norm_layer_type {config.attn_norm_layer_type}"
809
+ )
810
+ else:
811
+ self.prenorm = nn.Identity()
812
+
813
+ self.pruned_heads = set()
814
+
815
+ def prune_heads(self, heads):
816
+ if len(heads) == 0:
817
+ return
818
+ heads, index = find_pruneable_heads_and_indices(
819
+ heads,
820
+ self.self.num_attention_heads,
821
+ self.self.attention_head_size,
822
+ self.pruned_heads,
823
+ )
824
+
825
+ # Prune linear layers
826
+ self.self.query = prune_linear_layer(self.self.query, index)
827
+ self.self.key = prune_linear_layer(self.self.key, index)
828
+ self.self.value = prune_linear_layer(self.self.value, index)
829
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
830
+
831
+ # Update hyper params and store pruned heads
832
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
833
+ self.self.all_head_size = (
834
+ self.self.attention_head_size * self.self.num_attention_heads
835
+ )
836
+ self.pruned_heads = self.pruned_heads.union(heads)
837
+
838
+ def forward(
839
+ self,
840
+ hidden_states,
841
+ attention_mask=None,
842
+ head_mask=None,
843
+ encoder_hidden_states=None,
844
+ encoder_attention_mask=None,
845
+ past_key_value=None,
846
+ output_attentions=False,
847
+ ):
848
+ # if we are doing prenorm instead of postnorm
849
+ if isinstance(self.prenorm, nn.GroupNorm):
850
+ # group norm only works over channel dim
851
+ reshaped = hidden_states.permute(0, 2, 1)
852
+ hidden_states = self.prenorm(reshaped)
853
+ hidden_states = hidden_states.permute(0, 2, 1)
854
+ else:
855
+ hidden_states = self.prenorm(hidden_states)
856
+
857
+ self_outputs = self.self(
858
+ hidden_states,
859
+ attention_mask,
860
+ head_mask,
861
+ encoder_hidden_states,
862
+ encoder_attention_mask,
863
+ past_key_value,
864
+ output_attentions,
865
+ )
866
+ attention_output = self.output(self_outputs[0], hidden_states)
867
+ outputs = (attention_output,) + self_outputs[
868
+ 1:
869
+ ] # add attentions if we output them
870
+ return outputs
871
+
872
+
873
+ class BertPredictionHeadTransform(nn.Module):
874
+ def __init__(self, config):
875
+ super().__init__()
876
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
877
+ if isinstance(config.hidden_act, str):
878
+ self.transform_act_fn = get_activation(config.hidden_act)
879
+ else:
880
+ self.transform_act_fn = config.hidden_act
881
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
882
+
883
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
884
+ hidden_states = self.dense(hidden_states)
885
+ hidden_states = self.transform_act_fn(hidden_states)
886
+ hidden_states = self.LayerNorm(hidden_states)
887
+ return hidden_states
888
+
889
+
890
+ class BertLMPredictionHead(nn.Module):
891
+ def __init__(self, config):
892
+ super().__init__()
893
+ self.transform = BertPredictionHeadTransform(config)
894
+
895
+ # The output weights are the same as the input embeddings, but there is
896
+ # an output-only bias for each token.
897
+ if config.mup:
898
+ self.decoder = MuReadout(
899
+ config.hidden_size,
900
+ config.vocab_size,
901
+ output_mult=config.output_mult,
902
+ readout_zero_init=config.readout_zero_init,
903
+ bias=False,
904
+ )
905
+ else:
906
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
907
+
908
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
909
+
910
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
911
+ self.decoder.bias = self.bias
912
+
913
+ def forward(self, hidden_states):
914
+ hidden_states = self.transform(hidden_states)
915
+ hidden_states = self.decoder(hidden_states)
916
+ return hidden_states
917
+
918
+
919
+ class BertOnlyMLMHead(nn.Module):
920
+ def __init__(self, config):
921
+ super().__init__()
922
+ self.predictions = BertLMPredictionHead(config)
923
+
924
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
925
+ prediction_scores = self.predictions(sequence_output)
926
+ return prediction_scores
927
+
928
+
929
+ class BertForMaskedLM(BertPreTrainedModel):
930
+ def __init__(self, config):
931
+ super().__init__(config)
932
+
933
+ self.bert = BertModel(config)
934
+ self.cls = BertOnlyMLMHead(config)
935
+
936
+ self.init_weights()
937
+
938
+ def get_output_embeddings(self):
939
+ return self.cls.predictions.decoder
940
+
941
+ def set_output_embeddings(self, new_embeddings):
942
+ self.cls.predictions.decoder = new_embeddings
943
+
944
+ def forward(
945
+ self,
946
+ input_ids=None,
947
+ attention_mask=None,
948
+ token_type_ids=None,
949
+ position_ids=None,
950
+ head_mask=None,
951
+ inputs_embeds=None,
952
+ labels=None,
953
+ output_attentions=None,
954
+ output_hidden_states=None,
955
+ return_dict=None,
956
+ ):
957
+ r"""
958
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
959
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
960
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
961
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
962
+ """
963
+ return_dict = (
964
+ return_dict if return_dict is not None else self.config.use_return_dict
965
+ )
966
+
967
+ outputs = self.bert(
968
+ input_ids,
969
+ attention_mask,
970
+ token_type_ids,
971
+ position_ids,
972
+ head_mask,
973
+ inputs_embeds,
974
+ output_attentions,
975
+ output_hidden_states,
976
+ return_dict,
977
+ )
978
+
979
+ sequence_output = outputs[0]
980
+ prediction_scores = self.cls(sequence_output)
981
+
982
+ loss = None
983
+ # Masked language modeling softmax layer
984
+ if labels is not None:
985
+ loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
986
+ loss = loss_fct(
987
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
988
+ )
989
+
990
+ if not return_dict:
991
+ output = (prediction_scores,) + outputs[2:]
992
+ return ((loss,) + output) if loss is not None else output
993
+
994
+ return MaskedLMOutput(
995
+ loss=loss,
996
+ logits=prediction_scores,
997
+ hidden_states=outputs.hidden_states,
998
+ attentions=outputs.attentions,
999
+ )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b463d1df77bc9a3a3395099f491550d098f41f7850aaf6712f2d2df640c4f9a
3
+ size 1906060473