ligeti committed on
Commit 72342c2 (verified)
Parent: c981028

Upload ProkBertForMaskedLM

Files changed (4)
  1. config.json +11 -4
  2. generation_config.json +1 -1
  3. model.safetensors +3 -0
  4. models.py +295 -0
config.json CHANGED
@@ -1,23 +1,30 @@
 {
-  "_name_or_path": "/scratch/c_evolm/trained_models/prokbert-mini-k6s2/checkpoint-4800",
+  "_name_or_path": "/project/c_evolm/huggingface/prokbert-mini-long",
   "architectures": [
-    "MegatronBertForMaskedLM"
+    "ProkBertForMaskedLM"
   ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "models.ProkBertConfig",
+    "AutoModelForMaskedLM": "models.ProkBertForMaskedLM"
+  },
+  "classification_dropout_rate": 0.1,
   "hidden_act": "gelu",
   "hidden_dropout_prob": 0.1,
   "hidden_size": 384,
   "initializer_range": 0.02,
   "intermediate_size": 4096,
+  "kmer": 6,
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 2048,
-  "model_type": "megatron-bert",
+  "model_type": "prokbert",
   "num_attention_heads": 6,
   "num_hidden_layers": 6,
   "pad_token_id": 0,
   "position_embedding_type": "relative_key_query",
+  "shift": 2,
   "torch_dtype": "float32",
-  "transformers_version": "4.33.1",
+  "transformers_version": "4.48.0.dev0",
   "type_vocab_size": 2,
   "use_cache": true,
   "vocab_size": 4200
generation_config.json CHANGED
@@ -1,5 +1,5 @@
 {
   "_from_model_config": true,
   "pad_token_id": 0,
-  "transformers_version": "4.33.1"
+  "transformers_version": "4.48.0.dev0"
 }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:94979e16a54e3eabc577c436c81e4f64a4f89fd2317f8cfeff8ad1c7bb545683
+size 106351368
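
The weights file is tracked with Git LFS, so the commit only records the pointer above: the object's SHA-256 and its size in bytes. A small verification sketch (standard library only, assuming a locally downloaded model.safetensors) that checks the file against the values in the pointer:

import hashlib
import os

EXPECTED_OID = "94979e16a54e3eabc577c436c81e4f64a4f89fd2317f8cfeff8ad1c7bb545683"
EXPECTED_SIZE = 106351368  # bytes, from the LFS pointer
path = "model.safetensors"  # local copy fetched via git lfs or huggingface_hub

assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch"

sha256 = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)
assert sha256.hexdigest() == EXPECTED_OID, "hash mismatch"
print("model.safetensors matches the LFS pointer")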
models.py ADDED
@@ -0,0 +1,295 @@
+# coding=utf-8
+import warnings
+import logging
+from typing import Optional, Tuple, Union
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import MegatronBertConfig, MegatronBertModel, MegatronBertForMaskedLM, MegatronBertPreTrainedModel, PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput
+from transformers.utils.hub import cached_file
+
+
+class BertForBinaryClassificationWithPooling(nn.Module):
+    """
+    ProkBERT model for binary classification with custom pooling.
+
+    This model extends a pre-trained `MegatronBertModel` by adding a weighting layer
+    to compute a weighted sum over the sequence outputs, followed by a classifier.
+
+    Attributes:
+        base_model (MegatronBertModel): The base BERT model.
+        weighting_layer (nn.Linear): Linear layer to compute weights for each token.
+        dropout (nn.Dropout): Dropout layer.
+        classifier (nn.Linear): Linear layer for classification.
+    """
+    def __init__(self, base_model: MegatronBertModel):
+        """
+        Initialize the BertForBinaryClassificationWithPooling model.
+
+        Args:
+            base_model (MegatronBertModel): A pre-trained `MegatronBertModel` instance.
+        """
+        super(BertForBinaryClassificationWithPooling, self).__init__()
+        self.base_model = base_model
+        self.base_model_config_dict = base_model.config.to_dict()
+        self.hidden_size = self.base_model_config_dict['hidden_size']
+        self.dropout_rate = self.base_model_config_dict['hidden_dropout_prob']
+
+        self.weighting_layer = nn.Linear(self.hidden_size, 1)
+        self.dropout = nn.Dropout(self.dropout_rate)
+        self.classifier = nn.Linear(self.hidden_size, 2)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, output_hidden_states=False, output_pooled_output=False):
+        # Run the base model, optionally returning all hidden states
+        outputs = self.base_model(input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states)
+        sequence_output = outputs[0]
+
+        # Compute a weight for each position in the sequence
+        weights = self.weighting_layer(sequence_output)
+        weights = torch.nn.functional.softmax(weights, dim=1)
+
+        # Weighted sum over the sequence dimension (attention-style pooling)
+        pooled_output = torch.sum(weights * sequence_output, dim=1)
+
+        # Classification head
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        # Prepare the output as a dictionary
+        output = {"logits": logits}
+
+        # Include hidden states and pooled output if requested
+        if output_hidden_states:
+            output["hidden_states"] = outputs.hidden_states
+        if output_pooled_output:
+            output["pooled_output"] = pooled_output
+
+        # If labels are provided, compute the loss
+        if labels is not None:
+            loss_fct = torch.nn.CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, 2), labels.view(-1))
+            output["loss"] = loss
+
+        return output
+
+    def save_pretrained(self, save_directory):
+        """
+        Save the model weights and configuration in a directory.
+
+        Args:
+            save_directory (str): Directory where the model and configuration are saved.
+        """
+        if not os.path.exists(save_directory):
+            os.makedirs(save_directory)
+
+        model_path = os.path.join(save_directory, "pytorch_model.bin")
+        torch.save(self.state_dict(), model_path)
+        self.base_model.config.save_pretrained(save_directory)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Load the model weights and configuration from a local directory or the Hugging Face Hub.
+
+        Args:
+            pretrained_model_name_or_path (str): Directory where the model and configuration were saved, or a model id on the Hugging Face Hub.
+
+        Returns:
+            model: Instance of BertForBinaryClassificationWithPooling.
+        """
+        # Determine whether the path is local or a Hub model id
+        if os.path.exists(pretrained_model_name_or_path):
+            # Local path
+            if 'config' in kwargs:
+                config = kwargs['config']
+            else:
+                config = MegatronBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            base_model = MegatronBertModel(config=config)
+            model = cls(base_model=base_model)
+            model_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
+            model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
+        else:
+            # Hugging Face Hub model id
+            config = kwargs.pop('config', None)
+            if config is None:
+                config = MegatronBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+            base_model = MegatronBertModel(config=config)
+            model = cls(base_model=base_model)
+            model_file = cached_file(pretrained_model_name_or_path, "pytorch_model.bin")
+            model.load_state_dict(torch.load(model_file, map_location=torch.device('cpu'), weights_only=True))
+
+        return model
+
+
+class OldProkBertConfig(MegatronBertConfig):
+
+    model_type = "prokbert"
+
+    def __init__(
+        self,
+        kmer: int = 6,
+        shift: int = 1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.kmer = kmer
+        self.shift = shift
+
+
+class ProkBertConfig(MegatronBertConfig):
+    model_type = "prokbert"
+
+    def __init__(
+        self,
+        kmer: int = 6,
+        shift: int = 1,
+        num_labels: int = 2,
+        classification_dropout_rate: float = 0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.kmer = kmer
+        self.shift = shift
+        self.num_labels = num_labels
+        self.classification_dropout_rate = classification_dropout_rate
+
+
+class ProkBertClassificationConfig(ProkBertConfig):
+    model_type = "prokbert"
+
+    def __init__(
+        self,
+        num_labels: int = 2,
+        classification_dropout_rate: float = 0.1,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        # Some extra steps will be added here later; for now we simply reuse the plain config.
+        self.num_labels = num_labels
+        self.classification_dropout_rate = classification_dropout_rate
+
+class ProkBertPreTrainedModel(PreTrainedModel):
+    """
+    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+    models.
+    """
+
+    config_class = ProkBertConfig
+    base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
+
+    def _init_weights(self, module):
+        """Initialize the weights"""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+
+class ProkBertModel(MegatronBertModel):
+    config_class = ProkBertConfig
+
+    def __init__(self, config: ProkBertConfig, **kwargs):
+        if not isinstance(config, ProkBertConfig):
+            raise ValueError(f"Expected `ProkBertConfig`, got {config.__class__.__module__}.{config.__class__.__name__}")
+
+        super().__init__(config, **kwargs)
+        self.config = config
+        # One should check whether this is a proper ProkBERT config and, if not, craft one.
+
+
+class ProkBertForMaskedLM(MegatronBertForMaskedLM):
+    config_class = ProkBertConfig
+
+    def __init__(self, config: ProkBertConfig, **kwargs):
+        if not isinstance(config, ProkBertConfig):
+            raise ValueError(f"Expected `ProkBertConfig`, got {config.__class__.__module__}.{config.__class__.__name__}")
+
+        super().__init__(config, **kwargs)
+        self.config = config
+        # One should check whether this is a proper ProkBERT config and, if not, craft one.
+
+
+class ProkBertForSequenceClassification(ProkBertPreTrainedModel):
+    config_class = ProkBertConfig
+    base_model_prefix = "bert"
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.config = config
+        self.bert = ProkBertModel(config)
+        self.weighting_layer = nn.Linear(self.config.hidden_size, 1)
+        self.dropout = nn.Dropout(self.config.classification_dropout_rate)
+        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
+        self.loss_fct = torch.nn.CrossEntropyLoss()
+
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification loss. Indices should be in
+            `[0, ..., config.num_labels - 1]`; the loss is computed with cross-entropy.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.bert(
+            input_ids,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+
+        # Compute a weight for each position in the sequence
+        weights = self.weighting_layer(sequence_output)
+        weights = torch.nn.functional.softmax(weights, dim=1)
+        # Weighted sum over the sequence dimension (attention-style pooling)
+        pooled_output = torch.sum(weights * sequence_output, dim=1)
+        # Classification head
+        pooled_output = self.dropout(pooled_output)
+        logits = self.classifier(pooled_output)
+
+        # Compute the loss only when labels are provided
+        loss = None
+        if labels is not None:
+            loss = self.loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        classification_output = SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+        return classification_output
+
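
For reference, a short usage sketch of the classes added above. It assumes models.py is importable from the working directory, and it builds a randomly initialized ProkBertForSequenceClassification with the hyperparameters from config.json; this is illustrative only, not the authors' documented workflow.

import torch
from models import ProkBertConfig, ProkBertForSequenceClassification

# Hyperparameters mirror config.json; the weights here are randomly initialized.
config = ProkBertConfig(
    vocab_size=4200,
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=4096,
    max_position_embeddings=2048,
    position_embedding_type="relative_key_query",
    kmer=6,
    shift=2,
    num_labels=2,
)
model = ProkBertForSequenceClassification(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 32))  # dummy k-mer token ids
attention_mask = torch.ones_like(input_ids)
labels = torch.tensor([1])

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
print(out.loss, out.logits.shape)  # scalar cross-entropy loss, logits of shape (1, 2)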