lhallee committed on
Commit 700924b · verified · 1 Parent(s): 43d3de2

Update modeling_fastesm.py

Files changed (1):
  1. modeling_fastesm.py +553 -528
modeling_fastesm.py CHANGED
@@ -1,528 +1,553 @@
1
-
2
- import torch
3
- import torch.nn as nn
4
- from torch.nn import functional as F
5
- from typing import Optional, Tuple, Union
6
- from einops import rearrange
7
- from transformers.modeling_outputs import (
8
- MaskedLMOutput,
9
- BaseModelOutputWithPastAndCrossAttentions,
10
- BaseModelOutputWithPoolingAndCrossAttentions,
11
- SequenceClassifierOutput,
12
- TokenClassifierOutput
13
- )
14
- from transformers.models.esm.modeling_esm import (
15
- RotaryEmbedding,
16
- EsmContactPredictionHead,
17
- EsmIntermediate,
18
- EsmOutput,
19
- EsmPooler,
20
- EsmLMHead,
21
- EsmSelfOutput,
22
- EsmClassificationHead,
23
- EsmPreTrainedModel,
24
- create_position_ids_from_input_ids,
25
- gelu
26
- )
27
-
28
-
29
- class EsmEmbeddings(nn.Module):
30
- """
31
- Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
32
- """
33
-
34
- def __init__(self, config):
35
- super().__init__()
36
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
37
- if config.emb_layer_norm_before:
38
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
39
- else:
40
- self.layer_norm = None
41
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
42
- self.register_buffer(
43
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
44
- )
45
-
46
- self.padding_idx = config.pad_token_id
47
- self.position_embeddings = nn.Embedding(
48
- config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
49
- )
50
- # Token dropout does not work correctly so we disable it
51
- # self.token_dropout = config.token_dropout
52
- self.mask_token_id = config.mask_token_id
53
-
54
- def forward(
55
- self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
56
- ):
57
- if position_ids is None:
58
- if input_ids is not None:
59
- # Create the position ids from the input token ids. Any padded tokens remain padded.
60
- position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
61
- else:
62
- position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
63
-
64
- if inputs_embeds is None:
65
- inputs_embeds = self.word_embeddings(input_ids)
66
-
67
- embeddings = inputs_embeds
68
-
69
- if self.position_embedding_type == "absolute":
70
- position_embeddings = self.position_embeddings(position_ids)
71
- embeddings = embeddings + position_embeddings
72
-
73
- if self.layer_norm is not None:
74
- embeddings = self.layer_norm(embeddings)
75
- if attention_mask is not None:
76
- embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
77
- return embeddings
78
-
79
- def create_position_ids_from_inputs_embeds(self, inputs_embeds):
80
- """
81
- We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
82
-
83
- Args:
84
- inputs_embeds: torch.Tensor
85
-
86
- Returns: torch.Tensor
87
- """
88
- input_shape = inputs_embeds.size()[:-1]
89
- sequence_length = input_shape[1]
90
-
91
- position_ids = torch.arange(
92
- self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
93
- )
94
- return position_ids.unsqueeze(0).expand(input_shape)
95
-
96
-
97
- class EsmSelfAttention(nn.Module):
98
- def __init__(self, config, position_embedding_type=None):
99
- super().__init__()
100
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
101
- raise ValueError(
102
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
103
- f"heads ({config.num_attention_heads})"
104
- )
105
-
106
- self.num_attention_heads = config.num_attention_heads
107
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
108
- self.all_head_size = self.num_attention_heads * self.attention_head_size
109
-
110
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
111
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
112
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
113
- self.scale = self.attention_head_size**-0.5
114
-
115
- self.dropout_prob = config.attention_probs_dropout_prob
116
- self.position_embedding_type = position_embedding_type or getattr(
117
- config, "position_embedding_type", "absolute"
118
- )
119
- self.rotary_embeddings = None
120
- if self.position_embedding_type == "rotary":
121
- self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
122
-
123
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
124
- return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
125
-
126
- def forward(
127
- self,
128
- hidden_states: torch.Tensor,
129
- attention_mask: Optional[torch.FloatTensor] = None,
130
- ) -> Tuple[torch.Tensor]:
131
- query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
132
- key_layer = self.transpose_for_scores(self.key(hidden_states))
133
- value_layer = self.transpose_for_scores(self.value(hidden_states))
134
-
135
- if self.position_embedding_type == "rotary":
136
- query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
137
-
138
- context_layer = F.scaled_dot_product_attention(
139
- query_layer,
140
- key_layer,
141
- value_layer,
142
- attn_mask=attention_mask,
143
- dropout_p=self.dropout_prob,
144
- scale=1.0
145
- )
146
- return rearrange(context_layer, 'b h s d -> b s (h d)')
147
-
148
-
149
- class EsmAttention(nn.Module):
150
- def __init__(self, config):
151
- super().__init__()
152
- self.self = EsmSelfAttention(config)
153
- self.output = EsmSelfOutput(config)
154
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
155
-
156
- def forward(
157
- self,
158
- hidden_states,
159
- attention_mask=None,
160
- ):
161
- hidden_states_ln = self.LayerNorm(hidden_states)
162
- attention_output = self.self(
163
- hidden_states_ln,
164
- attention_mask,
165
- )
166
- return self.output(attention_output, hidden_states)
167
-
168
-
169
- class EsmLayer(nn.Module):
170
- def __init__(self, config):
171
- super().__init__()
172
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
173
- self.seq_len_dim = 1
174
- self.attention = EsmAttention(config)
175
- self.intermediate = EsmIntermediate(config)
176
- self.output = EsmOutput(config)
177
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
178
-
179
- def forward(
180
- self,
181
- hidden_states,
182
- attention_mask=None,
183
- ):
184
- attention_output = self.attention(
185
- hidden_states,
186
- attention_mask,
187
- )
188
- layer_output = self.feed_forward_chunk(attention_output)
189
- return layer_output
190
-
191
- def feed_forward_chunk(self, attention_output):
192
- attention_output_ln = self.LayerNorm(attention_output)
193
- intermediate_output = self.intermediate(attention_output_ln)
194
- layer_output = self.output(intermediate_output, attention_output)
195
- return layer_output
196
-
197
-
198
- class EsmEncoder(nn.Module):
199
- def __init__(self, config):
200
- super().__init__()
201
- self.config = config
202
- self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
203
- self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
204
- self.gradient_checkpointing = False
205
-
206
- def forward(
207
- self,
208
- hidden_states,
209
- attention_mask=None,
210
- output_hidden_states=False,
211
- ):
212
- all_hidden_states = () if output_hidden_states else None
213
- for layer_module in self.layer:
214
- if output_hidden_states:
215
- all_hidden_states = all_hidden_states + (hidden_states,)
216
-
217
- if self.gradient_checkpointing and self.training:
218
- hidden_states = self._gradient_checkpointing_func(
219
- layer_module.__call__,
220
- hidden_states,
221
- attention_mask,
222
- )
223
- else:
224
- hidden_states = layer_module(
225
- hidden_states,
226
- attention_mask,
227
- )
228
-
229
- if self.emb_layer_norm_after:
230
- hidden_states = self.emb_layer_norm_after(hidden_states)
231
-
232
- if output_hidden_states:
233
- all_hidden_states = all_hidden_states + (hidden_states,)
234
-
235
- return BaseModelOutputWithPastAndCrossAttentions(
236
- last_hidden_state=hidden_states,
237
- hidden_states=all_hidden_states,
238
- )
239
-
240
-
241
- class FastEsmModel(EsmPreTrainedModel):
242
- def __init__(self, config, add_pooling_layer=True):
243
- super().__init__(config)
244
- self.config = config
245
- self.embeddings = EsmEmbeddings(config)
246
- self.encoder = EsmEncoder(config)
247
- self.pooler = EsmPooler(config) if add_pooling_layer else None
248
- self.contact_head = EsmContactPredictionHead(
249
- in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
250
- )
251
- # Initialize weights and apply final processing
252
- self.post_init()
253
-
254
- def get_input_embeddings(self):
255
- return self.embeddings.word_embeddings
256
-
257
- def set_input_embeddings(self, value):
258
- self.embeddings.word_embeddings = value
259
-
260
- def forward(
261
- self,
262
- input_ids: Optional[torch.Tensor] = None,
263
- attention_mask: Optional[torch.Tensor] = None,
264
- position_ids: Optional[torch.Tensor] = None,
265
- inputs_embeds: Optional[torch.Tensor] = None,
266
- output_hidden_states: Optional[bool] = None,
267
- output_attentions: Optional[bool] = None,
268
- ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
269
- if output_attentions is not None:
270
- raise ValueError("output_attentions is not supported by F.scaled_dot_product_attention")
271
- output_hidden_states = (
272
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
273
- )
274
- if input_ids is not None and inputs_embeds is not None:
275
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
276
- elif input_ids is not None:
277
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
278
- input_shape = input_ids.size()
279
- elif inputs_embeds is not None:
280
- input_shape = inputs_embeds.size()[:-1]
281
- else:
282
- raise ValueError("You have to specify either input_ids or inputs_embeds")
283
-
284
- batch_size, seq_length = input_shape
285
- embedding_output = self.embeddings(
286
- input_ids=input_ids,
287
- position_ids=position_ids,
288
- attention_mask=attention_mask,
289
- inputs_embeds=inputs_embeds,
290
- )
291
- # Prepare attention mask
292
- if attention_mask is not None:
293
- # attention_mask shape should be (batch_size, 1, 1, seq_length)
294
- # Expand to (batch_size, 1, seq_length, seq_length)
295
- extended_attention_mask = attention_mask[:, None, None, :].expand(
296
- batch_size, 1, seq_length, seq_length
297
- )
298
- # Convert mask to float with 0.0 for positions to keep and -inf for masked positions
299
- attention_mask = attention_mask.to(dtype=embedding_output.dtype) # fp16 compatibility
300
- attention_mask = (1.0 - attention_mask) * torch.finfo(embedding_output.dtype).min
301
- else:
302
- extended_attention_mask = None
303
-
304
- encoder_outputs = self.encoder(
305
- embedding_output,
306
- attention_mask=extended_attention_mask,
307
- output_hidden_states=output_hidden_states,
308
- )
309
- sequence_output = encoder_outputs.last_hidden_state
310
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
311
-
312
- return BaseModelOutputWithPoolingAndCrossAttentions(
313
- last_hidden_state=sequence_output,
314
- pooler_output=pooled_output,
315
- hidden_states=encoder_outputs.hidden_states,
316
- )
317
-
318
-
319
- class FastEsmForMaskedLM(EsmPreTrainedModel):
320
- _tied_weights_keys = ["lm_head.decoder.weight"]
321
-
322
- def __init__(self, config):
323
- super().__init__(config)
324
- self.esm = FastEsmModel(config, add_pooling_layer=False)
325
- self.lm_head = EsmLMHead(config)
326
- self.loss_fct = nn.CrossEntropyLoss()
327
- self.init_weights()
328
-
329
- def get_output_embeddings(self):
330
- return self.lm_head.decoder
331
-
332
- def set_output_embeddings(self, new_embeddings):
333
- self.lm_head.decoder = new_embeddings
334
-
335
- def forward(
336
- self,
337
- input_ids: Optional[torch.LongTensor] = None,
338
- attention_mask: Optional[torch.Tensor] = None,
339
- position_ids: Optional[torch.LongTensor] = None,
340
- inputs_embeds: Optional[torch.FloatTensor] = None,
341
- labels: Optional[torch.LongTensor] = None,
342
- output_attentions: Optional[bool] = None,
343
- output_hidden_states: Optional[bool] = None,
344
- ) -> Union[Tuple, MaskedLMOutput]:
345
- outputs = self.esm(
346
- input_ids,
347
- attention_mask=attention_mask,
348
- position_ids=position_ids,
349
- inputs_embeds=inputs_embeds,
350
- output_hidden_states=output_hidden_states,
351
- output_attentions=output_attentions,
352
- )
353
- sequence_output = outputs.last_hidden_state
354
- prediction_scores = self.lm_head(sequence_output)
355
-
356
- loss = None
357
- if labels is not None:
358
- labels = labels.to(prediction_scores.device)
359
- loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
360
-
361
- return MaskedLMOutput(
362
- loss=loss,
363
- logits=prediction_scores,
364
- hidden_states=outputs.hidden_states,
365
- )
366
-
367
- def predict_contacts(self, tokens, attention_mask):
368
- raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
369
-
370
-
371
- class FastEsmForSequenceClassification(EsmPreTrainedModel):
372
- def __init__(self, config):
373
- super().__init__(config)
374
- self.num_labels = config.num_labels
375
- self.config = config
376
- self.esm = FastEsmModel(config, add_pooling_layer=False)
377
- self.classifier = EsmClassificationHead(config)
378
- self.mse = nn.MSELoss()
379
- self.ce = nn.CrossEntropyLoss()
380
- self.bce = nn.BCEWithLogitsLoss()
381
- self.init_weights()
382
-
383
- def forward(
384
- self,
385
- input_ids: Optional[torch.LongTensor] = None,
386
- attention_mask: Optional[torch.Tensor] = None,
387
- position_ids: Optional[torch.LongTensor] = None,
388
- inputs_embeds: Optional[torch.FloatTensor] = None,
389
- labels: Optional[torch.LongTensor] = None,
390
- output_attentions: Optional[bool] = None,
391
- output_hidden_states: Optional[bool] = None,
392
- ) -> Union[Tuple, SequenceClassifierOutput]:
393
- outputs = self.esm(
394
- input_ids,
395
- attention_mask=attention_mask,
396
- position_ids=position_ids,
397
- inputs_embeds=inputs_embeds,
398
- output_attentions=output_attentions,
399
- output_hidden_states=output_hidden_states,
400
- )
401
- sequence_output = outputs.last_hidden_state
402
- logits = self.classifier(sequence_output)
403
-
404
- loss = None
405
- if labels is not None:
406
- labels = labels.to(logits.device)
407
- if self.config.problem_type is None:
408
- if self.num_labels == 1:
409
- self.config.problem_type = "regression"
410
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
411
- self.config.problem_type = "single_label_classification"
412
- else:
413
- self.config.problem_type = "multi_label_classification"
414
-
415
- if self.config.problem_type == "regression":
416
- if self.num_labels == 1:
417
- loss = self.mse(logits.squeeze(), labels.squeeze())
418
- else:
419
- loss = self.mse(logits, labels)
420
- elif self.config.problem_type == "single_label_classification":
421
- loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
422
- elif self.config.problem_type == "multi_label_classification":
423
- loss = self.bce(logits, labels)
424
-
425
- return SequenceClassifierOutput(
426
- loss=loss,
427
- logits=logits,
428
- hidden_states=outputs.hidden_states,
429
- )
430
-
431
-
432
- class FastEsmForTokenClassification(EsmPreTrainedModel):
433
- def __init__(self, config):
434
- super().__init__(config)
435
- self.num_labels = config.num_labels
436
- self.esm = FastEsmModel(config, add_pooling_layer=False)
437
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
438
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
439
- self.loss_fct = nn.CrossEntropyLoss()
440
- self.init_weights()
441
-
442
- def forward(
443
- self,
444
- input_ids: Optional[torch.LongTensor] = None,
445
- attention_mask: Optional[torch.Tensor] = None,
446
- position_ids: Optional[torch.LongTensor] = None,
447
- inputs_embeds: Optional[torch.FloatTensor] = None,
448
- labels: Optional[torch.LongTensor] = None,
449
- output_attentions: Optional[bool] = None,
450
- output_hidden_states: Optional[bool] = None,
451
- ) -> Union[Tuple, TokenClassifierOutput]:
452
- outputs = self.esm(
453
- input_ids,
454
- attention_mask=attention_mask,
455
- position_ids=position_ids,
456
- inputs_embeds=inputs_embeds,
457
- output_attentions=output_attentions,
458
- output_hidden_states=output_hidden_states,
459
- )
460
- sequence_output = outputs.last_hidden_state
461
- sequence_output = self.dropout(sequence_output)
462
- logits = self.classifier(sequence_output)
463
-
464
- loss = None
465
- if labels is not None:
466
- labels = labels.to(logits.device)
467
- loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
468
-
469
- return TokenClassifierOutput(
470
- loss=loss,
471
- logits=logits,
472
- hidden_states=outputs.hidden_states,
473
- )
474
-
475
-
476
- if __name__ == "__main__":
477
- """
478
- Test the hidden state differences between the FastEsmModel and the HF EsmModel.
479
- In full precision, the differences are very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
480
- In PyTorch 2.5+ (with the fused SDPA kernels available on Linux), this implementation is much faster and uses less memory than the HF implementation.
481
- """
482
- import random
483
- from transformers import EsmModel as TransformersEsmModel, EsmTokenizer
484
-
485
- model_paths = [
486
- "facebook/esm2_t6_8M_UR50D",
487
- "facebook/esm2_t12_35M_UR50D",
488
- "facebook/esm2_t30_150M_UR50D",
489
- "facebook/esm2_t33_650M_UR50D",
490
- ]
491
- canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
492
- length = 64
493
- seq_count = 100
494
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
495
- tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
496
-
497
- def generate_random_sequence(length: int) -> str:
498
- return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
499
-
500
- print("Percentage of hidden states that are within the tolerance:")
501
- for model_path in model_paths:
502
- print(f"Testing {model_path}...")
503
- tokenizer = EsmTokenizer.from_pretrained(model_path)
504
- fast_model = FastEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
505
- model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
506
-
507
- counts = [0] * len(tolerances)
508
- for _ in range(seq_count):
509
- example_seq = generate_random_sequence(length)
510
- fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
511
- fast_output = fast_model(fast_tokens).last_hidden_state.detach().cpu()
512
-
513
- model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
514
- model_output = model(model_tokens).last_hidden_state.detach().cpu()
515
-
516
- for i, atol in enumerate(tolerances):
517
- if torch.allclose(fast_output, model_output, atol=atol):
518
- counts[i] += 1
519
-
520
- print(f"{model_path}:")
521
- for i, atol in enumerate(tolerances):
522
- print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
523
-
524
- model.cpu()
525
- fast_model.cpu()
526
- del model
527
- del fast_model
528
- torch.cuda.empty_cache()
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from typing import Optional, Tuple, Union
5
+ from einops import rearrange
6
+ from transformers import PreTrainedModel
7
+ from transformers.modeling_outputs import (
8
+ MaskedLMOutput,
9
+ BaseModelOutputWithPastAndCrossAttentions,
10
+ BaseModelOutputWithPoolingAndCrossAttentions,
11
+ SequenceClassifierOutput,
12
+ TokenClassifierOutput
13
+ )
14
+ from transformers.models.esm.modeling_esm import (
15
+ RotaryEmbedding,
16
+ EsmContactPredictionHead,
17
+ EsmIntermediate,
18
+ EsmOutput,
19
+ EsmPooler,
20
+ EsmLMHead,
21
+ EsmSelfOutput,
22
+ EsmClassificationHead,
23
+ create_position_ids_from_input_ids,
24
+ )
25
+ from .config_fastesm import FastEsmConfig
26
+
27
+
28
+ class EsmEmbeddings(nn.Module):
29
+ """
30
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
31
+ """
32
+
33
+ def __init__(self, config):
34
+ super().__init__()
35
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
36
+ if config.emb_layer_norm_before:
37
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
38
+ else:
39
+ self.layer_norm = None
40
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
41
+ self.register_buffer(
42
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
43
+ )
44
+
45
+ self.padding_idx = config.pad_token_id
46
+ self.position_embeddings = nn.Embedding(
47
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
48
+ )
49
+ # Token dropout does not work correctly so we disable it
50
+ # self.token_dropout = config.token_dropout
51
+ self.mask_token_id = config.mask_token_id
52
+
53
+ def forward(
54
+ self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
55
+ ):
56
+ if position_ids is None:
57
+ if input_ids is not None:
58
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
59
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
60
+ else:
61
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
62
+
63
+ if inputs_embeds is None:
64
+ inputs_embeds = self.word_embeddings(input_ids)
65
+
66
+ embeddings = inputs_embeds
67
+
68
+ if self.position_embedding_type == "absolute":
69
+ position_embeddings = self.position_embeddings(position_ids)
70
+ embeddings = embeddings + position_embeddings
71
+
72
+ if self.layer_norm is not None:
73
+ embeddings = self.layer_norm(embeddings)
74
+ if attention_mask is not None:
75
+ embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
76
+ return embeddings
77
+
78
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
79
+ """
80
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
81
+
82
+ Args:
83
+ inputs_embeds: torch.Tensor
84
+
85
+ Returns: torch.Tensor
86
+ """
87
+ input_shape = inputs_embeds.size()[:-1]
88
+ sequence_length = input_shape[1]
89
+
90
+ position_ids = torch.arange(
91
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
92
+ )
93
+ return position_ids.unsqueeze(0).expand(input_shape)
94
+
95
+
96
+ class EsmSelfAttention(nn.Module):
97
+ def __init__(self, config, position_embedding_type=None):
98
+ super().__init__()
99
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
100
+ raise ValueError(
101
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
102
+ f"heads ({config.num_attention_heads})"
103
+ )
104
+
105
+ self.num_attention_heads = config.num_attention_heads
106
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
107
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
108
+
109
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
110
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
111
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
112
+ self.scale = self.attention_head_size**-0.5
113
+
114
+ self.dropout_prob = config.attention_probs_dropout_prob
115
+ self.position_embedding_type = position_embedding_type or getattr(
116
+ config, "position_embedding_type", "absolute"
117
+ )
118
+ self.rotary_embeddings = None
119
+ if self.position_embedding_type == "rotary":
120
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
121
+
122
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
123
+ return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
124
+
125
+ def forward(
126
+ self,
127
+ hidden_states: torch.Tensor,
128
+ attention_mask: Optional[torch.FloatTensor] = None,
129
+ ) -> Tuple[torch.Tensor]:
130
+ query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
131
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
132
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
133
+
134
+ if self.position_embedding_type == "rotary":
135
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
136
+
137
+ context_layer = F.scaled_dot_product_attention(
138
+ query_layer,
139
+ key_layer,
140
+ value_layer,
141
+ attn_mask=attention_mask,
142
+ dropout_p=self.dropout_prob,
143
+ scale=1.0
144
+ )
145
+ return rearrange(context_layer, 'b h s d -> b s (h d)')
146
+
147
+
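EsmSelfAttention above scales the query projection by attention_head_size ** -0.5 up front and then calls F.scaled_dot_product_attention with scale=1.0, which is numerically the same as letting SDPA apply the 1/sqrt(d) factor itself. A minimal sketch of that equivalence with random tensors (shapes and tolerance are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, h, s, d = 2, 4, 16, 32                                  # batch, heads, seq len, head dim
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Reference attention: softmax(Q K^T / sqrt(d)) V
scores = (q @ k.transpose(-2, -1)) / d ** 0.5
reference = scores.softmax(dim=-1) @ v

# FastEsm-style call: queries are pre-scaled, so scale=1.0 inside SDPA
fused = F.scaled_dot_product_attention(q * d ** -0.5, k, v, scale=1.0)

# Letting SDPA use its default 1/sqrt(d) scaling gives the same result
default_scale = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(reference, fused, atol=1e-5))         # expected: True
print(torch.allclose(fused, default_scale, atol=1e-5))     # expected: True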
148
+ class EsmAttention(nn.Module):
149
+ def __init__(self, config):
150
+ super().__init__()
151
+ self.self = EsmSelfAttention(config)
152
+ self.output = EsmSelfOutput(config)
153
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
154
+
155
+ def forward(
156
+ self,
157
+ hidden_states,
158
+ attention_mask=None,
159
+ ):
160
+ hidden_states_ln = self.LayerNorm(hidden_states)
161
+ attention_output = self.self(
162
+ hidden_states_ln,
163
+ attention_mask,
164
+ )
165
+ return self.output(attention_output, hidden_states)
166
+
167
+
168
+ class EsmLayer(nn.Module):
169
+ def __init__(self, config):
170
+ super().__init__()
171
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
172
+ self.seq_len_dim = 1
173
+ self.attention = EsmAttention(config)
174
+ self.intermediate = EsmIntermediate(config)
175
+ self.output = EsmOutput(config)
176
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states,
181
+ attention_mask=None,
182
+ ):
183
+ attention_output = self.attention(
184
+ hidden_states,
185
+ attention_mask,
186
+ )
187
+ layer_output = self.feed_forward_chunk(attention_output)
188
+ return layer_output
189
+
190
+ def feed_forward_chunk(self, attention_output):
191
+ attention_output_ln = self.LayerNorm(attention_output)
192
+ intermediate_output = self.intermediate(attention_output_ln)
193
+ layer_output = self.output(intermediate_output, attention_output)
194
+ return layer_output
195
+
196
+
197
+ class EsmEncoder(nn.Module):
198
+ def __init__(self, config):
199
+ super().__init__()
200
+ self.config = config
201
+ self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
202
+ self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
203
+ self.gradient_checkpointing = False
204
+
205
+ def forward(
206
+ self,
207
+ hidden_states,
208
+ attention_mask=None,
209
+ output_hidden_states=False,
210
+ ):
211
+ all_hidden_states = () if output_hidden_states else None
212
+ for layer_module in self.layer:
213
+ if output_hidden_states:
214
+ all_hidden_states = all_hidden_states + (hidden_states,)
215
+
216
+ if self.gradient_checkpointing and self.training:
217
+ hidden_states = self._gradient_checkpointing_func(
218
+ layer_module.__call__,
219
+ hidden_states,
220
+ attention_mask,
221
+ )
222
+ else:
223
+ hidden_states = layer_module(
224
+ hidden_states,
225
+ attention_mask,
226
+ )
227
+
228
+ if self.emb_layer_norm_after:
229
+ hidden_states = self.emb_layer_norm_after(hidden_states)
230
+
231
+ if output_hidden_states:
232
+ all_hidden_states = all_hidden_states + (hidden_states,)
233
+
234
+ return BaseModelOutputWithPastAndCrossAttentions(
235
+ last_hidden_state=hidden_states,
236
+ hidden_states=all_hidden_states,
237
+ )
238
+
239
+
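The encoder collects hidden states before each layer and once more after the final emb_layer_norm_after, so requesting them yields num_hidden_layers + 1 tensors. A small sketch using FastEsmModel (defined just below) and the same public ESM2 checkpoint as the self-test at the bottom of the file; the toy token ids are illustrative:

import torch

model = FastEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D", token_dropout=False)
model.eval()

input_ids = torch.randint(4, 24, (1, 32))                  # toy token ids in the amino-acid range
with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

# One entry per layer input plus the final normalized output (7 for the 6-layer checkpoint).
print(len(out.hidden_states), model.config.num_hidden_layers + 1)
print(out.hidden_states[-1].shape)                         # (1, 32, hidden_size), same as last_hidden_state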
240
+ class FastEsmPreTrainedModel(PreTrainedModel):
241
+ """
242
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
243
+ models.
244
+ """
245
+ config_class = FastEsmConfig
246
+ base_model_prefix = "fastesm"
247
+ supports_gradient_checkpointing = True
248
+ def _init_weights(self, module):
249
+ """Initialize the weights"""
250
+ if isinstance(module, nn.Linear):
251
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
252
+ if module.bias is not None:
253
+ module.bias.data.zero_()
254
+ elif isinstance(module, nn.Embedding):
255
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
256
+ if module.padding_idx is not None:
257
+ module.weight.data[module.padding_idx].zero_()
258
+ elif isinstance(module, nn.LayerNorm):
259
+ module.bias.data.zero_()
260
+ module.weight.data.fill_(1.0)
261
+
262
+ def get_input_embeddings(self) -> nn.Module:
263
+ try:
264
+ return self.embeddings.word_embeddings
265
+ except AttributeError:
266
+ return self.esm.embeddings.word_embeddings
267
+
268
+
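FastEsmPreTrainedModel declares supports_gradient_checkpointing = True, which lets the standard transformers toggle reach the encoder's gradient_checkpointing flag above. A hedged sketch of enabling it for training; the exact recompute path and memory savings depend on the installed transformers/PyTorch versions:

import torch

model = FastEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D", token_dropout=False)
model.gradient_checkpointing_enable()                      # flips encoder.gradient_checkpointing to True
model.train()

input_ids = torch.randint(4, 24, (2, 64))                  # toy token ids, illustrative only
out = model(input_ids)
out.last_hidden_state.pow(2).mean().backward()             # dummy objective just to exercise the recompute path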
269
+ class FastEsmModel(FastEsmPreTrainedModel):
270
+ def __init__(self, config, add_pooling_layer=True):
271
+ super().__init__(config)
272
+ self.config = config
273
+ self.embeddings = EsmEmbeddings(config)
274
+ self.encoder = EsmEncoder(config)
275
+ self.pooler = EsmPooler(config) if add_pooling_layer else None
276
+ # Initialize weights and apply final processing
277
+ self.post_init()
278
+
279
+ def get_input_embeddings(self):
280
+ return self.embeddings.word_embeddings
281
+
282
+ def set_input_embeddings(self, value):
283
+ self.embeddings.word_embeddings = value
284
+
285
+ def forward(
286
+ self,
287
+ input_ids: Optional[torch.Tensor] = None,
288
+ attention_mask: Optional[torch.Tensor] = None,
289
+ position_ids: Optional[torch.Tensor] = None,
290
+ inputs_embeds: Optional[torch.Tensor] = None,
291
+ output_hidden_states: Optional[bool] = None,
292
+ output_attentions: Optional[bool] = None,
293
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
294
+ if output_attentions is not None:
295
+ raise ValueError("output_attentions is not supported by F.scaled_dot_product_attention")
296
+ output_hidden_states = (
297
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
298
+ )
299
+ if input_ids is not None and inputs_embeds is not None:
300
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
301
+ elif input_ids is not None:
302
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
303
+ input_shape = input_ids.size()
304
+ elif inputs_embeds is not None:
305
+ input_shape = inputs_embeds.size()[:-1]
306
+ else:
307
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
308
+
309
+ batch_size, seq_length = input_shape
310
+ embedding_output = self.embeddings(
311
+ input_ids=input_ids,
312
+ position_ids=position_ids,
313
+ attention_mask=attention_mask,
314
+ inputs_embeds=inputs_embeds,
315
+ )
316
+ # Prepare attention mask
317
+ if attention_mask is not None:
318
+ # attention_mask shape should be (batch_size, 1, 1, seq_length)
319
+ # Expand to (batch_size, 1, seq_length, seq_length)
320
+ extended_attention_mask = attention_mask[:, None, None, :].expand(
321
+ batch_size, 1, seq_length, seq_length
322
+ )
323
+ # Convert mask to float with 0.0 for positions to keep and -inf for masked positions
324
+ attention_mask = attention_mask.to(dtype=embedding_output.dtype) # fp16 compatibility
325
+ attention_mask = (1.0 - attention_mask) * torch.finfo(embedding_output.dtype).min
326
+ else:
327
+ extended_attention_mask = None
328
+
329
+ encoder_outputs = self.encoder(
330
+ embedding_output,
331
+ attention_mask=extended_attention_mask,
332
+ output_hidden_states=output_hidden_states,
333
+ )
334
+ sequence_output = encoder_outputs.last_hidden_state
335
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
336
+
337
+ return BaseModelOutputWithPoolingAndCrossAttentions(
338
+ last_hidden_state=sequence_output,
339
+ pooler_output=pooled_output,
340
+ hidden_states=encoder_outputs.hidden_states,
341
+ )
342
+
343
+
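As in the parity test at the bottom of the file, the base model can be pointed at a public ESM2 checkpoint and used to pull per-residue embeddings; mean pooling over the residue positions is one common way to get a single protein-level vector (the pooling choice is an illustration, not something this file prescribes):

import torch
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = FastEsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D", token_dropout=False)
model.eval()

seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"                  # toy protein sequence
inputs = tokenizer(seq, return_tensors="pt")
with torch.no_grad():
    out = model(inputs.input_ids)

residue_embeddings = out.last_hidden_state                 # (1, seq_len + 2, hidden_size), includes cls/eos
protein_embedding = residue_embeddings[:, 1:-1].mean(dim=1)  # drop special tokens, average over residues
print(protein_embedding.shape)                             # (1, hidden_size)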
344
+ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
345
+ _tied_weights_keys = ["lm_head.decoder.weight"]
346
+
347
+ def __init__(self, config):
348
+ super().__init__(config)
349
+ self.esm = FastEsmModel(config, add_pooling_layer=False)
350
+ self.lm_head = EsmLMHead(config)
351
+ self.loss_fct = nn.CrossEntropyLoss()
352
+ self.init_weights()
353
+
354
+ def get_output_embeddings(self):
355
+ return self.lm_head.decoder
356
+
357
+ def set_output_embeddings(self, new_embeddings):
358
+ self.lm_head.decoder = new_embeddings
359
+
360
+ def forward(
361
+ self,
362
+ input_ids: Optional[torch.LongTensor] = None,
363
+ attention_mask: Optional[torch.Tensor] = None,
364
+ position_ids: Optional[torch.LongTensor] = None,
365
+ inputs_embeds: Optional[torch.FloatTensor] = None,
366
+ labels: Optional[torch.LongTensor] = None,
367
+ output_attentions: Optional[bool] = None,
368
+ output_hidden_states: Optional[bool] = None,
369
+ ) -> Union[Tuple, MaskedLMOutput]:
370
+ outputs = self.esm(
371
+ input_ids,
372
+ attention_mask=attention_mask,
373
+ position_ids=position_ids,
374
+ inputs_embeds=inputs_embeds,
375
+ output_hidden_states=output_hidden_states,
376
+ output_attentions=output_attentions,
377
+ )
378
+ sequence_output = outputs.last_hidden_state
379
+ prediction_scores = self.lm_head(sequence_output)
380
+
381
+ loss = None
382
+ if labels is not None:
383
+ labels = labels.to(prediction_scores.device)
384
+ loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
385
+
386
+ return MaskedLMOutput(
387
+ loss=loss,
388
+ logits=prediction_scores,
389
+ hidden_states=outputs.hidden_states,
390
+ )
391
+
392
+ def predict_contacts(self, tokens, attention_mask):
393
+ raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
394
+
395
+
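A short sketch of masked-token prediction with FastEsmForMaskedLM, again borrowing the public ESM2 checkpoint used by the self-test; the sequence and the masked position are made up for illustration:

import torch
from transformers import EsmTokenizer

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = FastEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D", token_dropout=False)
model.eval()

seq = "MKTAYIAKQRQISFVK"
input_ids = tokenizer(seq, return_tensors="pt").input_ids
input_ids[0, 5] = tokenizer.mask_token_id                  # hide one residue

with torch.no_grad():
    logits = model(input_ids).logits                       # (1, seq_len + 2, vocab_size)

predicted_id = logits[0, 5].argmax().item()
print(tokenizer.convert_ids_to_tokens(predicted_id))       # model's best guess for the masked residue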
396
+ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
397
+ def __init__(self, config):
398
+ super().__init__(config)
399
+ self.num_labels = config.num_labels
400
+ self.config = config
401
+ self.esm = FastEsmModel(config, add_pooling_layer=False)
402
+ self.classifier = EsmClassificationHead(config)
403
+ self.mse = nn.MSELoss()
404
+ self.ce = nn.CrossEntropyLoss()
405
+ self.bce = nn.BCEWithLogitsLoss()
406
+ self.init_weights()
407
+
408
+ def forward(
409
+ self,
410
+ input_ids: Optional[torch.LongTensor] = None,
411
+ attention_mask: Optional[torch.Tensor] = None,
412
+ position_ids: Optional[torch.LongTensor] = None,
413
+ inputs_embeds: Optional[torch.FloatTensor] = None,
414
+ labels: Optional[torch.LongTensor] = None,
415
+ output_attentions: Optional[bool] = None,
416
+ output_hidden_states: Optional[bool] = None,
417
+ ) -> Union[Tuple, SequenceClassifierOutput]:
418
+ outputs = self.esm(
419
+ input_ids,
420
+ attention_mask=attention_mask,
421
+ position_ids=position_ids,
422
+ inputs_embeds=inputs_embeds,
423
+ output_attentions=output_attentions,
424
+ output_hidden_states=output_hidden_states,
425
+ )
426
+ sequence_output = outputs.last_hidden_state
427
+ logits = self.classifier(sequence_output)
428
+
429
+ loss = None
430
+ if labels is not None:
431
+ labels = labels.to(logits.device)
432
+ if self.config.problem_type is None:
433
+ if self.num_labels == 1:
434
+ self.config.problem_type = "regression"
435
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
436
+ self.config.problem_type = "single_label_classification"
437
+ else:
438
+ self.config.problem_type = "multi_label_classification"
439
+
440
+ if self.config.problem_type == "regression":
441
+ if self.num_labels == 1:
442
+ loss = self.mse(logits.squeeze(), labels.squeeze())
443
+ else:
444
+ loss = self.mse(logits, labels)
445
+ elif self.config.problem_type == "single_label_classification":
446
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
447
+ elif self.config.problem_type == "multi_label_classification":
448
+ loss = self.bce(logits, labels)
449
+
450
+ return SequenceClassifierOutput(
451
+ loss=loss,
452
+ logits=logits,
453
+ hidden_states=outputs.hidden_states,
454
+ )
455
+
456
+
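The sequence-classification head resolves config.problem_type from num_labels and the label dtype, then picks MSE, cross-entropy, or BCE-with-logits accordingly. A compact sketch of the label shapes that trigger each branch; the tensors are illustrative and only their dtypes/shapes matter, and passing token_dropout/num_labels as config kwargs mirrors the self-test below:

import torch

batch = 4
regression_labels = torch.randn(batch)                           # num_labels == 1 -> "regression" (MSELoss)
single_label = torch.randint(0, 3, (batch,), dtype=torch.long)   # int labels, num_labels > 1 -> CrossEntropyLoss
multi_label = torch.randint(0, 2, (batch, 3)).float()            # float multi-hot labels -> BCEWithLogitsLoss

model = FastEsmForSequenceClassification.from_pretrained(
    "facebook/esm2_t6_8M_UR50D", num_labels=3, token_dropout=False
)
input_ids = torch.randint(4, 24, (batch, 32))                    # toy token ids
loss = model(input_ids, labels=single_label).loss                # problem_type becomes "single_label_classification"
print(loss.item())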
457
+ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
458
+ def __init__(self, config):
459
+ super().__init__(config)
460
+ self.num_labels = config.num_labels
461
+ self.esm = FastEsmModel(config, add_pooling_layer=False)
462
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
463
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
464
+ self.loss_fct = nn.CrossEntropyLoss()
465
+ self.init_weights()
466
+
467
+ def forward(
468
+ self,
469
+ input_ids: Optional[torch.LongTensor] = None,
470
+ attention_mask: Optional[torch.Tensor] = None,
471
+ position_ids: Optional[torch.LongTensor] = None,
472
+ inputs_embeds: Optional[torch.FloatTensor] = None,
473
+ labels: Optional[torch.LongTensor] = None,
474
+ output_attentions: Optional[bool] = None,
475
+ output_hidden_states: Optional[bool] = None,
476
+ ) -> Union[Tuple, TokenClassifierOutput]:
477
+ outputs = self.esm(
478
+ input_ids,
479
+ attention_mask=attention_mask,
480
+ position_ids=position_ids,
481
+ inputs_embeds=inputs_embeds,
482
+ output_attentions=output_attentions,
483
+ output_hidden_states=output_hidden_states,
484
+ )
485
+ sequence_output = outputs.last_hidden_state
486
+ sequence_output = self.dropout(sequence_output)
487
+ logits = self.classifier(sequence_output)
488
+
489
+ loss = None
490
+ if labels is not None:
491
+ labels = labels.to(logits.device)
492
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
493
+
494
+ return TokenClassifierOutput(
495
+ loss=loss,
496
+ logits=logits,
497
+ hidden_states=outputs.hidden_states,
498
+ )
499
+
500
+
501
+ if __name__ == "__main__":
502
+ """
503
+ Test the hidden state differences between the FastEsmModel and the HF EsmModel.
504
+ In full precision, the differences are very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
505
+ In PyTorch 2.5+ (with the fused SDPA kernels available on Linux), this implementation is much faster and uses less memory than the HF implementation.
506
+ """
507
+ import random
508
+ from transformers import EsmModel as TransformersEsmModel, EsmTokenizer
509
+
510
+ model_paths = [
511
+ "facebook/esm2_t6_8M_UR50D",
512
+ "facebook/esm2_t12_35M_UR50D",
513
+ "facebook/esm2_t30_150M_UR50D",
514
+ "facebook/esm2_t33_650M_UR50D",
515
+ ]
516
+ canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
517
+ length = 64
518
+ seq_count = 100
519
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
520
+ tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
521
+
522
+ def generate_random_sequence(length: int) -> str:
523
+ return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
524
+
525
+ print("Percentage of hidden states that are within the tolerance:")
526
+ for model_path in model_paths:
527
+ print(f"Testing {model_path}...")
528
+ tokenizer = EsmTokenizer.from_pretrained(model_path)
529
+ fast_model = FastEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
530
+ model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
531
+
532
+ counts = [0] * len(tolerances)
533
+ for _ in range(seq_count):
534
+ example_seq = generate_random_sequence(length)
535
+ fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
536
+ fast_output = fast_model(fast_tokens).last_hidden_state.detach().cpu()
537
+
538
+ model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
539
+ model_output = model(model_tokens).last_hidden_state.detach().cpu()
540
+
541
+ for i, atol in enumerate(tolerances):
542
+ if torch.allclose(fast_output, model_output, atol=atol):
543
+ counts[i] += 1
544
+
545
+ print(f"{model_path}:")
546
+ for i, atol in enumerate(tolerances):
547
+ print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
548
+
549
+ model.cpu()
550
+ fast_model.cpu()
551
+ del model
552
+ del fast_model
553
+ torch.cuda.empty_cache()
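Finally, since this modeling file ships inside a Hub repo rather than transformers itself, downstream users would typically load it through the Auto classes with trust_remote_code=True. A hedged sketch: the repository id below is a hypothetical placeholder, and whether the repo's config wires up auto_map this way is an assumption, not something shown in this commit.

import torch
from transformers import AutoModel, AutoTokenizer

repo_id = "<namespace>/<fastesm-checkpoint>"               # hypothetical placeholder, not a real repo id

tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)  # assumed to resolve to FastEsmModel via auto_map
model.eval()

inputs = tokenizer("MKTAYIAKQRQISFVK", return_tensors="pt")
with torch.no_grad():
    out = model(inputs.input_ids)
print(out.last_hidden_state.shape)                         # (1, seq_len + 2, hidden_size)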