lhallee commited on
Commit
888b364
·
verified ·
1 Parent(s): bfb370f

Upload modeling_fastesm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_fastesm.py +998 -998
modeling_fastesm.py CHANGED
@@ -1,998 +1,998 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.nn import functional as F
4
- from torch.utils.data import Dataset, DataLoader
5
- from typing import Optional, Tuple, Union
6
- from einops import rearrange
7
- from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
8
- from transformers.modeling_outputs import (
9
- MaskedLMOutput,
10
- BaseModelOutputWithPastAndCrossAttentions,
11
- BaseModelOutputWithPoolingAndCrossAttentions,
12
- SequenceClassifierOutput,
13
- TokenClassifierOutput
14
- )
15
- from transformers.models.esm.modeling_esm import (
16
- EsmIntermediate,
17
- EsmOutput,
18
- EsmPooler,
19
- EsmLMHead,
20
- EsmSelfOutput,
21
- EsmClassificationHead,
22
- )
23
- from tqdm.auto import tqdm
24
-
25
-
26
- class FastEsmConfig(PretrainedConfig):
27
- model_type = "fast_esm"
28
- def __init__(
29
- self,
30
- vocab_size=None,
31
- mask_token_id=None,
32
- pad_token_id=None,
33
- hidden_size=768,
34
- num_hidden_layers=12,
35
- num_attention_heads=12,
36
- intermediate_size=3072,
37
- hidden_dropout_prob=0.1,
38
- attention_probs_dropout_prob=0.1,
39
- max_position_embeddings=1026,
40
- initializer_range=0.02,
41
- layer_norm_eps=1e-12,
42
- position_embedding_type="absolute",
43
- emb_layer_norm_before=None,
44
- **kwargs,
45
- ):
46
- super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
47
-
48
- self.vocab_size = vocab_size
49
- self.hidden_size = hidden_size
50
- self.num_hidden_layers = num_hidden_layers
51
- self.num_attention_heads = num_attention_heads
52
- self.intermediate_size = intermediate_size
53
- self.hidden_dropout_prob = hidden_dropout_prob
54
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
55
- self.max_position_embeddings = max_position_embeddings
56
- self.initializer_range = initializer_range
57
- self.layer_norm_eps = layer_norm_eps
58
- self.position_embedding_type = position_embedding_type
59
- self.emb_layer_norm_before = emb_layer_norm_before
60
-
61
- def to_dict(self):
62
- """
63
- Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
64
-
65
- Returns:
66
- `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
67
- """
68
- output = super().to_dict()
69
- return output
70
-
71
-
72
- def rotate_half(x):
73
- x1, x2 = x.chunk(2, dim=-1)
74
- return torch.cat((-x2, x1), dim=-1)
75
-
76
-
77
- def apply_rotary_pos_emb(x, cos, sin):
78
- cos = cos[:, :, : x.shape[-2], :]
79
- sin = sin[:, :, : x.shape[-2], :]
80
-
81
- return (x * cos) + (rotate_half(x) * sin)
82
-
83
-
84
- def symmetrize(x):
85
- "Make layer symmetric in final two dimensions, used for contact prediction."
86
- return x + x.transpose(-1, -2)
87
-
88
-
89
- def average_product_correct(x):
90
- "Perform average product correct, used for contact prediction."
91
- a1 = x.sum(-1, keepdims=True)
92
- a2 = x.sum(-2, keepdims=True)
93
- a12 = x.sum((-1, -2), keepdims=True)
94
-
95
- avg = a1 * a2
96
- avg.div_(a12) # in-place to reduce memory
97
- normalized = x - avg
98
- return normalized
99
-
100
-
101
- class EsmContactPredictionHead(nn.Module):
102
- """Performs symmetrization, apc, and computes a logistic regression on the output features"""
103
-
104
- def __init__(
105
- self,
106
- in_features: int,
107
- bias=True,
108
- eos_idx: int = 2,
109
- ):
110
- super().__init__()
111
- self.in_features = in_features
112
- self.eos_idx = eos_idx
113
- self.regression = nn.Linear(in_features, 1, bias)
114
- self.activation = nn.Sigmoid()
115
-
116
- def forward(self, tokens, attentions):
117
- # remove eos token attentions
118
- eos_mask = tokens.ne(self.eos_idx).to(attentions)
119
- eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
120
- attentions = attentions * eos_mask[:, None, None, :, :]
121
- attentions = attentions[..., :-1, :-1]
122
- # remove cls token attentions
123
- attentions = attentions[..., 1:, 1:]
124
- batch_size, layers, heads, seqlen, _ = attentions.size()
125
- attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
126
-
127
- # features: batch x channels x tokens x tokens (symmetric)
128
- attentions = attentions.to(
129
- self.regression.weight.device
130
- ) # attentions always float32, may need to convert to float16
131
- attentions = average_product_correct(symmetrize(attentions))
132
- attentions = attentions.permute(0, 2, 3, 1)
133
- return self.activation(self.regression(attentions).squeeze(3))
134
-
135
-
136
- class RotaryEmbedding(torch.nn.Module):
137
- """
138
- Rotary position embeddings based on those in
139
- [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
140
- matrices which depend on their relative positions.
141
- """
142
-
143
- def __init__(self, dim: int):
144
- super().__init__()
145
- # Generate and save the inverse frequency buffer (non trainable)
146
- inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
147
- inv_freq = inv_freq
148
- self.register_buffer("inv_freq", inv_freq)
149
-
150
- self._seq_len_cached = None
151
- self._cos_cached = None
152
- self._sin_cached = None
153
-
154
- def _update_cos_sin_tables(self, x, seq_dimension=2):
155
- seq_len = x.shape[seq_dimension]
156
-
157
- # Reset the tables if the sequence length has changed,
158
- # or if we're on a new device (possibly due to tracing for instance)
159
- if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
160
- self._seq_len_cached = seq_len
161
- t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
162
- freqs = torch.outer(t, self.inv_freq)
163
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
164
-
165
- self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
166
- self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
167
-
168
- return self._cos_cached, self._sin_cached
169
-
170
- def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
171
- self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
172
-
173
- return (
174
- apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
175
- apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
176
- )
177
-
178
-
179
- class EsmEmbeddings(nn.Module):
180
- """
181
- Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
182
- """
183
-
184
- def __init__(self, config):
185
- super().__init__()
186
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
187
- if config.emb_layer_norm_before:
188
- self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
189
- else:
190
- self.layer_norm = None
191
- self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
192
- self.register_buffer(
193
- "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
194
- )
195
-
196
- def forward(
197
- self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
198
- ):
199
- if inputs_embeds is None:
200
- inputs_embeds = self.word_embeddings(input_ids)
201
-
202
- embeddings = inputs_embeds
203
-
204
- if self.layer_norm is not None:
205
- embeddings = self.layer_norm(embeddings)
206
- if attention_mask is not None:
207
- embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
208
- return embeddings
209
-
210
- def create_position_ids_from_inputs_embeds(self, inputs_embeds):
211
- """
212
- We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
213
-
214
- Args:
215
- inputs_embeds: torch.Tensor
216
-
217
- Returns: torch.Tensor
218
- """
219
- input_shape = inputs_embeds.size()[:-1]
220
- sequence_length = input_shape[1]
221
-
222
- position_ids = torch.arange(
223
- self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
224
- )
225
- return position_ids.unsqueeze(0).expand(input_shape)
226
-
227
-
228
- class EsmSelfAttention(nn.Module):
229
- def __init__(self, config, position_embedding_type=None):
230
- super().__init__()
231
- if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
232
- raise ValueError(
233
- f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
234
- f"heads ({config.num_attention_heads})"
235
- )
236
-
237
- self.num_attention_heads = config.num_attention_heads
238
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
239
- self.all_head_size = self.num_attention_heads * self.attention_head_size
240
-
241
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
242
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
243
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
244
- self.scale = self.attention_head_size**-0.5
245
-
246
- self.dropout_prob = config.attention_probs_dropout_prob
247
- self.position_embedding_type = position_embedding_type or getattr(
248
- config, "position_embedding_type", "absolute"
249
- )
250
- self.rotary_embeddings = None
251
- if self.position_embedding_type == "rotary":
252
- self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
253
-
254
- def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
255
- return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
256
-
257
- def forward(
258
- self,
259
- hidden_states: torch.Tensor,
260
- attention_mask: Optional[torch.FloatTensor] = None,
261
- output_attentions: bool = False,
262
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
263
- """Forward pass for self attention.
264
-
265
- Args:
266
- hidden_states: Input tensor
267
- attention_mask: Optional attention mask
268
- output_attentions: Whether to return attention weights
269
-
270
- Returns:
271
- Output tensor and optionally attention weights
272
- """
273
- query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
274
- key_layer = self.transpose_for_scores(self.key(hidden_states))
275
- value_layer = self.transpose_for_scores(self.value(hidden_states))
276
-
277
- if self.position_embedding_type == "rotary":
278
- query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
279
-
280
- if output_attentions:
281
- # Manual attention computation to get attention weights
282
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
- if attention_mask is not None:
284
- attention_scores = attention_scores + attention_mask
285
- attention_probs = F.softmax(attention_scores, dim=-1)
286
- if self.dropout_prob > 0:
287
- attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
288
- context_layer = torch.matmul(attention_probs, value_layer)
289
- context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
290
- return context_layer, attention_probs
291
- else:
292
- context_layer = F.scaled_dot_product_attention(
293
- query_layer,
294
- key_layer,
295
- value_layer,
296
- attn_mask=attention_mask,
297
- dropout_p=self.dropout_prob,
298
- scale=1.0
299
- )
300
- context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
301
- return context_layer
302
-
303
-
304
- class EsmAttention(nn.Module):
305
- def __init__(self, config):
306
- super().__init__()
307
- self.self = EsmSelfAttention(config)
308
- self.output = EsmSelfOutput(config)
309
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
310
-
311
- def forward(
312
- self,
313
- hidden_states: torch.Tensor,
314
- attention_mask: Optional[torch.FloatTensor] = None,
315
- output_attentions: bool = False,
316
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
317
- """Forward pass for attention layer.
318
-
319
- Args:
320
- hidden_states: Input tensor
321
- attention_mask: Optional attention mask
322
- output_attentions: Whether to return attention weights
323
-
324
- Returns:
325
- Output tensor and optionally attention weights
326
- """
327
- hidden_states_ln = self.LayerNorm(hidden_states)
328
- self_outputs = self.self(
329
- hidden_states_ln,
330
- attention_mask,
331
- output_attentions,
332
- )
333
- if output_attentions:
334
- attention_output, attention_weights = self_outputs
335
- attention_output = self.output(attention_output, hidden_states)
336
- return attention_output, attention_weights
337
- else:
338
- attention_output = self_outputs
339
- return self.output(attention_output, hidden_states)
340
-
341
-
342
- class EsmLayer(nn.Module):
343
- def __init__(self, config):
344
- super().__init__()
345
- self.chunk_size_feed_forward = config.chunk_size_feed_forward
346
- self.seq_len_dim = 1
347
- self.attention = EsmAttention(config)
348
- self.intermediate = EsmIntermediate(config)
349
- self.output = EsmOutput(config)
350
- self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
351
-
352
- def forward(
353
- self,
354
- hidden_states: torch.Tensor,
355
- attention_mask: Optional[torch.FloatTensor] = None,
356
- output_attentions: bool = False,
357
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
358
- """Forward pass for transformer layer.
359
-
360
- Args:
361
- hidden_states: Input tensor
362
- attention_mask: Optional attention mask
363
- output_attentions: Whether to return attention weights
364
-
365
- Returns:
366
- Output tensor and optionally attention weights
367
- """
368
- attention_outputs = self.attention(
369
- hidden_states,
370
- attention_mask,
371
- output_attentions,
372
- )
373
- if output_attentions:
374
- attention_output, attention_weights = attention_outputs
375
- else:
376
- attention_output = attention_outputs
377
- attention_weights = None
378
-
379
- layer_output = self.feed_forward_chunk(attention_output)
380
-
381
- if output_attentions:
382
- return layer_output, attention_weights
383
- return layer_output
384
-
385
- def feed_forward_chunk(self, attention_output):
386
- attention_output_ln = self.LayerNorm(attention_output)
387
- intermediate_output = self.intermediate(attention_output_ln)
388
- layer_output = self.output(intermediate_output, attention_output)
389
- return layer_output
390
-
391
-
392
- class EsmEncoder(nn.Module):
393
- def __init__(self, config):
394
- super().__init__()
395
- self.config = config
396
- self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
397
- self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
398
- self.gradient_checkpointing = False
399
-
400
- def forward(
401
- self,
402
- hidden_states: torch.Tensor,
403
- attention_mask: Optional[torch.FloatTensor] = None,
404
- output_hidden_states: bool = False,
405
- output_attentions: bool = False,
406
- ) -> BaseModelOutputWithPastAndCrossAttentions:
407
- """Forward pass for transformer encoder.
408
-
409
- Args:
410
- hidden_states: Input tensor
411
- attention_mask: Optional attention mask
412
- output_hidden_states: Whether to return all hidden states
413
- output_attentions: Whether to return attention weights
414
-
415
- Returns:
416
- BaseModelOutputWithPastAndCrossAttentions containing model outputs
417
- """
418
- all_hidden_states = () if output_hidden_states else None
419
- all_attentions = () if output_attentions else None
420
-
421
- for layer_module in self.layer:
422
- if output_hidden_states:
423
- all_hidden_states = all_hidden_states + (hidden_states,)
424
-
425
- if self.gradient_checkpointing and self.training:
426
- layer_outputs = self._gradient_checkpointing_func(
427
- layer_module.__call__,
428
- hidden_states,
429
- attention_mask,
430
- output_attentions,
431
- )
432
- else:
433
- layer_outputs = layer_module(
434
- hidden_states,
435
- attention_mask,
436
- output_attentions,
437
- )
438
-
439
- if output_attentions:
440
- hidden_states, attention_weights = layer_outputs
441
- all_attentions = all_attentions + (attention_weights,)
442
- else:
443
- hidden_states = layer_outputs
444
-
445
- if self.emb_layer_norm_after:
446
- hidden_states = self.emb_layer_norm_after(hidden_states)
447
-
448
- if output_hidden_states:
449
- all_hidden_states = all_hidden_states + (hidden_states,)
450
-
451
- return BaseModelOutputWithPastAndCrossAttentions(
452
- last_hidden_state=hidden_states,
453
- hidden_states=all_hidden_states,
454
- attentions=all_attentions,
455
- )
456
-
457
-
458
- ### Dataset for Embedding
459
- class ProteinDataset(Dataset):
460
- """Simple dataset for protein sequences."""
461
- def __init__(self, sequences: list[str]):
462
- self.sequences = sequences
463
-
464
- def __len__(self) -> int:
465
- return len(self.sequences)
466
-
467
- def __getitem__(self, idx: int) -> str:
468
- return self.sequences[idx]
469
-
470
-
471
- class FastEsmPreTrainedModel(PreTrainedModel):
472
- """
473
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
474
- models.
475
- """
476
- config_class = FastEsmConfig
477
- base_model_prefix = "fastesm"
478
- supports_gradient_checkpointing = True
479
- tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
480
- def _init_weights(self, module):
481
- """Initialize the weights"""
482
- if isinstance(module, nn.Linear):
483
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
484
- if module.bias is not None:
485
- module.bias.data.zero_()
486
- elif isinstance(module, nn.Embedding):
487
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
488
- if module.padding_idx is not None:
489
- module.weight.data[module.padding_idx].zero_()
490
- elif isinstance(module, nn.LayerNorm):
491
- module.bias.data.zero_()
492
- module.weight.data.fill_(1.0)
493
-
494
- def get_input_embeddings(self) -> nn.Module:
495
- try:
496
- return self.embeddings.word_embeddings
497
- except AttributeError:
498
- return self.esm.embeddings.word_embeddings
499
-
500
- @property
501
- def device(self) -> torch.device:
502
- """Get the device of the model."""
503
- return next(self.parameters()).device
504
-
505
- def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
506
- """Apply mean pooling to sequence outputs."""
507
- if attention_mask is None:
508
- return x.mean(dim=1)
509
- else:
510
- attention_mask = attention_mask.unsqueeze(-1)
511
- return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
512
-
513
- def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
514
- """Collate function for batching sequences."""
515
- return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
516
-
517
- def _read_sequences_from_db(self, db_path: str) -> set[str]:
518
- """Read sequences from SQLite database."""
519
- import sqlite3
520
- sequences = []
521
- with sqlite3.connect(db_path) as conn:
522
- c = conn.cursor()
523
- c.execute("SELECT sequence FROM embeddings")
524
- while True:
525
- row = c.fetchone()
526
- if row is None:
527
- break
528
- sequences.append(row[0])
529
- return set(sequences)
530
-
531
- def embed_dataset(
532
- self,
533
- sequences: list[str],
534
- batch_size: int = 2,
535
- max_len: int = 512,
536
- full_embeddings: bool = False,
537
- full_precision: bool = False,
538
- pooling_type: str = 'mean',
539
- num_workers: int = 0,
540
- sql: bool = False,
541
- sql_db_path: str = 'embeddings.db',
542
- ) -> Optional[dict[str, torch.Tensor]]:
543
- """Embed a dataset of protein sequences.
544
-
545
- Args:
546
- sequences: List of protein sequences
547
- batch_size: Batch size for processing
548
- max_len: Maximum sequence length
549
- full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
550
- full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
551
- pooling_type: Type of pooling ('mean' or 'cls')
552
- num_workers: Number of workers for data loading, 0 for the main process
553
- sql: Whether to store embeddings in SQLite database - will be stored in float32
554
- sql_db_path: Path to SQLite database
555
-
556
- Returns:
557
- Dictionary mapping sequences to embeddings, or None if sql=True
558
- """
559
- sequences = list(set([seq[:max_len] for seq in sequences]))
560
- sequences = sorted(sequences, key=len, reverse=True)
561
- dataset = ProteinDataset(sequences)
562
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
563
- device = self.device
564
-
565
- def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
566
- if full_embeddings:
567
- return residue_embeddings
568
- elif pooling_type == 'mean':
569
- return self.mean_pooling(residue_embeddings, attention_mask)
570
- else:
571
- return residue_embeddings[:, 0, :]
572
-
573
- if sql:
574
- import sqlite3
575
- conn = sqlite3.connect(sql_db_path)
576
- c = conn.cursor()
577
- c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
578
- already_embedded = self._read_sequences_from_db(sql_db_path)
579
- to_embed = [seq for seq in sequences if seq not in already_embedded]
580
- print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
581
- print(f"Embedding {len(to_embed)} new sequences")
582
- if len(to_embed) > 0:
583
- with torch.no_grad():
584
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
585
- seqs = sequences[i * batch_size:(i + 1) * batch_size]
586
- input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
587
- residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
588
- embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
589
-
590
- for seq, emb in zip(seqs, embeddings):
591
- c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
592
- (seq, emb.cpu().numpy().tobytes()))
593
-
594
- if (i + 1) % 100 == 0:
595
- conn.commit()
596
-
597
- conn.commit()
598
- conn.close()
599
- return None
600
-
601
- embeddings_dict = {}
602
- with torch.no_grad():
603
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
604
- seqs = sequences[i * batch_size:(i + 1) * batch_size]
605
- input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
606
- residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()
607
- if full_precision:
608
- residue_embeddings = residue_embeddings.float()
609
- embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
610
- for seq, emb in zip(seqs, embeddings):
611
- embeddings_dict[seq] = emb
612
-
613
- return embeddings_dict
614
-
615
-
616
- class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
617
- def __init__(self, config, add_pooling_layer=True):
618
- super().__init__(config)
619
- self.config = config
620
- self.embeddings = EsmEmbeddings(config)
621
- self.encoder = EsmEncoder(config)
622
- # Initialize weights and apply final processing
623
- self.post_init()
624
-
625
- def get_input_embeddings(self):
626
- return self.embeddings.word_embeddings
627
-
628
- def set_input_embeddings(self, value):
629
- self.embeddings.word_embeddings = value
630
-
631
- def forward(
632
- self,
633
- input_ids: Optional[torch.LongTensor] = None,
634
- attention_mask: Optional[torch.Tensor] = None,
635
- position_ids: Optional[torch.LongTensor] = None,
636
- inputs_embeds: Optional[torch.FloatTensor] = None,
637
- output_attentions: Optional[bool] = None,
638
- output_hidden_states: Optional[bool] = None,
639
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
640
- ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
641
- """Forward pass for base model.
642
-
643
- Args:
644
- input_ids: Input token IDs
645
- attention_mask: Optional attention mask
646
- position_ids: Optional position IDs
647
- inputs_embeds: Optional input embeddings
648
- output_hidden_states: Whether to return all hidden states
649
- output_attentions: Whether to return attention weights
650
-
651
- Returns:
652
- Model outputs including hidden states and optionally attention weights
653
- """
654
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
655
- output_hidden_states = (
656
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
657
- )
658
-
659
- if input_ids is not None and inputs_embeds is not None:
660
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
661
- elif input_ids is not None:
662
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
663
- input_shape = input_ids.size()
664
- elif inputs_embeds is not None:
665
- input_shape = inputs_embeds.size()[:-1]
666
- else:
667
- raise ValueError("You have to specify either input_ids or inputs_embeds")
668
-
669
- batch_size, seq_length = input_shape
670
- embedding_output = self.embeddings(
671
- input_ids=input_ids,
672
- position_ids=position_ids,
673
- attention_mask=attention_mask,
674
- inputs_embeds=inputs_embeds,
675
- )
676
-
677
- if attention_mask is not None:
678
- extended_attention_mask = attention_mask[:, None, None, :].expand(
679
- batch_size, 1, seq_length, seq_length
680
- ).bool()
681
- else:
682
- extended_attention_mask = None
683
-
684
- encoder_outputs = self.encoder(
685
- embedding_output,
686
- attention_mask=extended_attention_mask,
687
- output_hidden_states=output_hidden_states,
688
- output_attentions=output_attentions,
689
- )
690
- sequence_output = encoder_outputs.last_hidden_state
691
-
692
- return BaseModelOutputWithPoolingAndCrossAttentions(
693
- last_hidden_state=sequence_output,
694
- hidden_states=encoder_outputs.hidden_states,
695
- attentions=encoder_outputs.attentions,
696
- )
697
-
698
-
699
- class FastEsmModel(FastEsmPreTrainedModel):
700
- def __init__(self, config, add_pooling_layer=True):
701
- super().__init__(config)
702
- self.config = config
703
- self.esm = FAST_ESM_ENCODER(config)
704
- self.pooler = EsmPooler(config) if add_pooling_layer else None
705
- # Initialize weights and apply final processing
706
- self.post_init()
707
-
708
- def get_input_embeddings(self):
709
- return self.embeddings.word_embeddings
710
-
711
- def set_input_embeddings(self, value):
712
- self.embeddings.word_embeddings = value
713
-
714
- def forward(
715
- self,
716
- input_ids: Optional[torch.LongTensor] = None,
717
- attention_mask: Optional[torch.Tensor] = None,
718
- position_ids: Optional[torch.LongTensor] = None,
719
- inputs_embeds: Optional[torch.FloatTensor] = None,
720
- output_attentions: Optional[bool] = None,
721
- output_hidden_states: Optional[bool] = None,
722
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
723
- ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
724
- """Forward pass for base model.
725
-
726
- Args:
727
- input_ids: Input token IDs
728
- attention_mask: Optional attention mask
729
- position_ids: Optional position IDs
730
- inputs_embeds: Optional input embeddings
731
- output_hidden_states: Whether to return all hidden states
732
- output_attentions: Whether to return attention weights
733
-
734
- Returns:
735
- Model outputs including hidden states and optionally attention weights
736
- """
737
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
738
- output_hidden_states = (
739
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
740
- )
741
-
742
- if input_ids is not None and inputs_embeds is not None:
743
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
744
- elif input_ids is not None:
745
- self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
746
- input_shape = input_ids.size()
747
- elif inputs_embeds is not None:
748
- input_shape = inputs_embeds.size()[:-1]
749
- else:
750
- raise ValueError("You have to specify either input_ids or inputs_embeds")
751
-
752
- batch_size, seq_length = input_shape
753
- embedding_output = self.embeddings(
754
- input_ids=input_ids,
755
- position_ids=position_ids,
756
- attention_mask=attention_mask,
757
- inputs_embeds=inputs_embeds,
758
- )
759
-
760
- if attention_mask is not None:
761
- extended_attention_mask = attention_mask[:, None, None, :].expand(
762
- batch_size, 1, seq_length, seq_length
763
- ).bool()
764
- else:
765
- extended_attention_mask = None
766
-
767
- encoder_outputs = self.encoder(
768
- embedding_output,
769
- attention_mask=extended_attention_mask,
770
- output_hidden_states=output_hidden_states,
771
- output_attentions=output_attentions,
772
- )
773
- sequence_output = encoder_outputs.last_hidden_state
774
- pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
775
-
776
- return BaseModelOutputWithPoolingAndCrossAttentions(
777
- last_hidden_state=sequence_output,
778
- pooler_output=pooled_output,
779
- hidden_states=encoder_outputs.hidden_states,
780
- attentions=encoder_outputs.attentions,
781
- )
782
-
783
-
784
- class FastEsmForMaskedLM(FastEsmPreTrainedModel):
785
- _tied_weights_keys = ["lm_head.decoder.weight"]
786
-
787
- def __init__(self, config):
788
- super().__init__(config)
789
- self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
790
- self.lm_head = EsmLMHead(config)
791
- self.loss_fct = nn.CrossEntropyLoss()
792
- self.init_weights()
793
-
794
- def get_output_embeddings(self):
795
- return self.lm_head.decoder
796
-
797
- def set_output_embeddings(self, new_embeddings):
798
- self.lm_head.decoder = new_embeddings
799
-
800
- def forward(
801
- self,
802
- input_ids: Optional[torch.LongTensor] = None,
803
- attention_mask: Optional[torch.Tensor] = None,
804
- position_ids: Optional[torch.LongTensor] = None,
805
- inputs_embeds: Optional[torch.FloatTensor] = None,
806
- labels: Optional[torch.LongTensor] = None,
807
- output_attentions: Optional[bool] = None,
808
- output_hidden_states: Optional[bool] = None,
809
- return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
810
- ) -> Union[Tuple, MaskedLMOutput]:
811
- outputs = self.esm(
812
- input_ids,
813
- attention_mask=attention_mask,
814
- position_ids=position_ids,
815
- inputs_embeds=inputs_embeds,
816
- output_hidden_states=output_hidden_states,
817
- output_attentions=output_attentions,
818
- )
819
- sequence_output = outputs.last_hidden_state
820
- prediction_scores = self.lm_head(sequence_output)
821
-
822
- loss = None
823
- if labels is not None:
824
- labels = labels.to(prediction_scores.device)
825
- loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
826
-
827
- return MaskedLMOutput(
828
- loss=loss,
829
- logits=prediction_scores,
830
- hidden_states=outputs.hidden_states,
831
- attentions=outputs.attentions,
832
- )
833
-
834
- def predict_contacts(self, tokens, attention_mask):
835
- raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
836
-
837
-
838
- class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
839
- def __init__(self, config):
840
- super().__init__(config)
841
- self.num_labels = config.num_labels
842
- self.config = config
843
- self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
844
- self.classifier = EsmClassificationHead(config)
845
- self.mse = nn.MSELoss()
846
- self.ce = nn.CrossEntropyLoss()
847
- self.bce = nn.BCEWithLogitsLoss()
848
- self.init_weights()
849
-
850
- def forward(
851
- self,
852
- input_ids: Optional[torch.LongTensor] = None,
853
- attention_mask: Optional[torch.Tensor] = None,
854
- position_ids: Optional[torch.LongTensor] = None,
855
- inputs_embeds: Optional[torch.FloatTensor] = None,
856
- labels: Optional[torch.LongTensor] = None,
857
- output_attentions: Optional[bool] = None,
858
- output_hidden_states: Optional[bool] = None,
859
- ) -> Union[Tuple, SequenceClassifierOutput]:
860
- outputs = self.esm(
861
- input_ids,
862
- attention_mask=attention_mask,
863
- position_ids=position_ids,
864
- inputs_embeds=inputs_embeds,
865
- output_attentions=output_attentions,
866
- output_hidden_states=output_hidden_states,
867
- )
868
- sequence_output = outputs.last_hidden_state
869
- logits = self.classifier(sequence_output)
870
-
871
- loss = None
872
- if labels is not None:
873
- labels = labels.to(logits.device)
874
- if self.config.problem_type is None:
875
- if self.num_labels == 1:
876
- self.config.problem_type = "regression"
877
- elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
878
- self.config.problem_type = "single_label_classification"
879
- else:
880
- self.config.problem_type = "multi_label_classification"
881
-
882
- if self.config.problem_type == "regression":
883
- if self.num_labels == 1:
884
- loss = self.mse(logits.squeeze(), labels.squeeze())
885
- else:
886
- loss = self.mse(logits, labels)
887
- elif self.config.problem_type == "single_label_classification":
888
- loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
889
- elif self.config.problem_type == "multi_label_classification":
890
- loss = self.bce(logits, labels)
891
-
892
- return SequenceClassifierOutput(
893
- loss=loss,
894
- logits=logits,
895
- hidden_states=outputs.hidden_states,
896
- attentions=outputs.attentions,
897
- )
898
-
899
-
900
- class FastEsmForTokenClassification(FastEsmPreTrainedModel):
901
- def __init__(self, config):
902
- super().__init__(config)
903
- self.num_labels = config.num_labels
904
- self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
905
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
906
- self.classifier = nn.Linear(config.hidden_size, config.num_labels)
907
- self.loss_fct = nn.CrossEntropyLoss()
908
- self.init_weights()
909
-
910
- def forward(
911
- self,
912
- input_ids: Optional[torch.LongTensor] = None,
913
- attention_mask: Optional[torch.Tensor] = None,
914
- position_ids: Optional[torch.LongTensor] = None,
915
- inputs_embeds: Optional[torch.FloatTensor] = None,
916
- labels: Optional[torch.LongTensor] = None,
917
- output_attentions: Optional[bool] = None,
918
- output_hidden_states: Optional[bool] = None,
919
- ) -> Union[Tuple, TokenClassifierOutput]:
920
- outputs = self.esm(
921
- input_ids,
922
- attention_mask=attention_mask,
923
- position_ids=position_ids,
924
- inputs_embeds=inputs_embeds,
925
- output_attentions=output_attentions,
926
- output_hidden_states=output_hidden_states,
927
- )
928
- sequence_output = outputs.last_hidden_state
929
- sequence_output = self.dropout(sequence_output)
930
- logits = self.classifier(sequence_output)
931
-
932
- loss = None
933
- if labels is not None:
934
- labels = labels.to(logits.device)
935
- loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
936
-
937
- return TokenClassifierOutput(
938
- loss=loss,
939
- logits=logits,
940
- hidden_states=outputs.hidden_states,
941
- attentions=outputs.attentions,
942
- )
943
-
944
-
945
- if __name__ == "__main__":
946
- """
947
- Test the hidden state differences between the FastEsmModel and the HF EsmModel.
948
- In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
949
- In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
950
- """
951
- import random
952
- from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer
953
-
954
- model_paths = [
955
- "facebook/esm2_t6_8M_UR50D",
956
- "facebook/esm2_t12_35M_UR50D",
957
- #"facebook/esm2_t30_150M_UR50D",
958
- #"facebook/esm2_t33_650M_UR50D",
959
- ]
960
- canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
961
- length = 64
962
- seq_count = 100
963
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
964
- tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
965
-
966
- def generate_random_sequence(length: int) -> str:
967
- return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
968
-
969
- print("Percentage of hidden states that are within the tolerance:")
970
- for model_path in model_paths:
971
- print(f"Testing {model_path}...")
972
- tokenizer = EsmTokenizer.from_pretrained(model_path)
973
- config = FastEsmConfig.from_pretrained(model_path)
974
- fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
975
- model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
976
-
977
- counts = [0] * len(tolerances)
978
- for _ in range(seq_count):
979
- example_seq = generate_random_sequence(length)
980
- fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
981
- fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
982
-
983
- model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
984
- model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
985
-
986
- for i, atol in enumerate(tolerances):
987
- if torch.allclose(fast_output, model_output, atol=atol):
988
- counts[i] += 1
989
-
990
- print(f"{model_path}:")
991
- for i, atol in enumerate(tolerances):
992
- print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
993
-
994
- model.cpu()
995
- fast_model.cpu()
996
- del model
997
- del fast_model
998
- torch.cuda.empty_cache()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.nn import functional as F
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from typing import Optional, Tuple, Union
6
+ from einops import rearrange
7
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
8
+ from transformers.modeling_outputs import (
9
+ MaskedLMOutput,
10
+ BaseModelOutputWithPastAndCrossAttentions,
11
+ BaseModelOutputWithPoolingAndCrossAttentions,
12
+ SequenceClassifierOutput,
13
+ TokenClassifierOutput
14
+ )
15
+ from transformers.models.esm.modeling_esm import (
16
+ EsmIntermediate,
17
+ EsmOutput,
18
+ EsmPooler,
19
+ EsmLMHead,
20
+ EsmSelfOutput,
21
+ EsmClassificationHead,
22
+ )
23
+ from tqdm.auto import tqdm
24
+
25
+
26
+ class FastEsmConfig(PretrainedConfig):
27
+ model_type = "fast_esm"
28
+ def __init__(
29
+ self,
30
+ vocab_size=None,
31
+ mask_token_id=None,
32
+ pad_token_id=None,
33
+ hidden_size=768,
34
+ num_hidden_layers=12,
35
+ num_attention_heads=12,
36
+ intermediate_size=3072,
37
+ hidden_dropout_prob=0.1,
38
+ attention_probs_dropout_prob=0.1,
39
+ max_position_embeddings=1026,
40
+ initializer_range=0.02,
41
+ layer_norm_eps=1e-12,
42
+ position_embedding_type="absolute",
43
+ emb_layer_norm_before=None,
44
+ **kwargs,
45
+ ):
46
+ super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
47
+
48
+ self.vocab_size = vocab_size
49
+ self.hidden_size = hidden_size
50
+ self.num_hidden_layers = num_hidden_layers
51
+ self.num_attention_heads = num_attention_heads
52
+ self.intermediate_size = intermediate_size
53
+ self.hidden_dropout_prob = hidden_dropout_prob
54
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
55
+ self.max_position_embeddings = max_position_embeddings
56
+ self.initializer_range = initializer_range
57
+ self.layer_norm_eps = layer_norm_eps
58
+ self.position_embedding_type = position_embedding_type
59
+ self.emb_layer_norm_before = emb_layer_norm_before
60
+
61
+ def to_dict(self):
62
+ """
63
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
64
+
65
+ Returns:
66
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
67
+ """
68
+ output = super().to_dict()
69
+ return output
70
+
71
+
72
+ def rotate_half(x):
73
+ x1, x2 = x.chunk(2, dim=-1)
74
+ return torch.cat((-x2, x1), dim=-1)
75
+
76
+
77
+ def apply_rotary_pos_emb(x, cos, sin):
78
+ cos = cos[:, :, : x.shape[-2], :]
79
+ sin = sin[:, :, : x.shape[-2], :]
80
+
81
+ return (x * cos) + (rotate_half(x) * sin)
82
+
83
+
84
+ def symmetrize(x):
85
+ "Make layer symmetric in final two dimensions, used for contact prediction."
86
+ return x + x.transpose(-1, -2)
87
+
88
+
89
+ def average_product_correct(x):
90
+ "Perform average product correct, used for contact prediction."
91
+ a1 = x.sum(-1, keepdims=True)
92
+ a2 = x.sum(-2, keepdims=True)
93
+ a12 = x.sum((-1, -2), keepdims=True)
94
+
95
+ avg = a1 * a2
96
+ avg.div_(a12) # in-place to reduce memory
97
+ normalized = x - avg
98
+ return normalized
99
+
100
+
101
+ class EsmContactPredictionHead(nn.Module):
102
+ """Performs symmetrization, apc, and computes a logistic regression on the output features"""
103
+
104
+ def __init__(
105
+ self,
106
+ in_features: int,
107
+ bias=True,
108
+ eos_idx: int = 2,
109
+ ):
110
+ super().__init__()
111
+ self.in_features = in_features
112
+ self.eos_idx = eos_idx
113
+ self.regression = nn.Linear(in_features, 1, bias)
114
+ self.activation = nn.Sigmoid()
115
+
116
+ def forward(self, tokens, attentions):
117
+ # remove eos token attentions
118
+ eos_mask = tokens.ne(self.eos_idx).to(attentions)
119
+ eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
120
+ attentions = attentions * eos_mask[:, None, None, :, :]
121
+ attentions = attentions[..., :-1, :-1]
122
+ # remove cls token attentions
123
+ attentions = attentions[..., 1:, 1:]
124
+ batch_size, layers, heads, seqlen, _ = attentions.size()
125
+ attentions = attentions.view(batch_size, layers * heads, seqlen, seqlen)
126
+
127
+ # features: batch x channels x tokens x tokens (symmetric)
128
+ attentions = attentions.to(
129
+ self.regression.weight.device
130
+ ) # attentions always float32, may need to convert to float16
131
+ attentions = average_product_correct(symmetrize(attentions))
132
+ attentions = attentions.permute(0, 2, 3, 1)
133
+ return self.activation(self.regression(attentions).squeeze(3))
134
+
135
+
136
+ class RotaryEmbedding(torch.nn.Module):
137
+ """
138
+ Rotary position embeddings based on those in
139
+ [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
140
+ matrices which depend on their relative positions.
141
+ """
142
+
143
+ def __init__(self, dim: int):
144
+ super().__init__()
145
+ # Generate and save the inverse frequency buffer (non trainable)
146
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
147
+ inv_freq = inv_freq
148
+ self.register_buffer("inv_freq", inv_freq)
149
+
150
+ self._seq_len_cached = None
151
+ self._cos_cached = None
152
+ self._sin_cached = None
153
+
154
+ def _update_cos_sin_tables(self, x, seq_dimension=2):
155
+ seq_len = x.shape[seq_dimension]
156
+
157
+ # Reset the tables if the sequence length has changed,
158
+ # or if we're on a new device (possibly due to tracing for instance)
159
+ if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
160
+ self._seq_len_cached = seq_len
161
+ t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
162
+ freqs = torch.outer(t, self.inv_freq)
163
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
164
+
165
+ self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype)
166
+ self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype)
167
+
168
+ return self._cos_cached, self._sin_cached
169
+
170
+ def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
171
+ self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
172
+
173
+ return (
174
+ apply_rotary_pos_emb(q, self._cos_cached, self._sin_cached),
175
+ apply_rotary_pos_emb(k, self._cos_cached, self._sin_cached),
176
+ )
177
+
178
+
179
+ class EsmEmbeddings(nn.Module):
180
+ """
181
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
182
+ """
183
+
184
+ def __init__(self, config):
185
+ super().__init__()
186
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
187
+ if config.emb_layer_norm_before:
188
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
189
+ else:
190
+ self.layer_norm = None
191
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
192
+ self.register_buffer(
193
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
194
+ )
195
+
196
+ def forward(
197
+ self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
198
+ ):
199
+ if inputs_embeds is None:
200
+ inputs_embeds = self.word_embeddings(input_ids)
201
+
202
+ embeddings = inputs_embeds
203
+
204
+ if self.layer_norm is not None:
205
+ embeddings = self.layer_norm(embeddings)
206
+ if attention_mask is not None:
207
+ embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
208
+ return embeddings
209
+
210
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
211
+ """
212
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
213
+
214
+ Args:
215
+ inputs_embeds: torch.Tensor
216
+
217
+ Returns: torch.Tensor
218
+ """
219
+ input_shape = inputs_embeds.size()[:-1]
220
+ sequence_length = input_shape[1]
221
+
222
+ position_ids = torch.arange(
223
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
224
+ )
225
+ return position_ids.unsqueeze(0).expand(input_shape)
226
+
227
+
228
+ class EsmSelfAttention(nn.Module):
229
+ def __init__(self, config, position_embedding_type=None):
230
+ super().__init__()
231
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
232
+ raise ValueError(
233
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
234
+ f"heads ({config.num_attention_heads})"
235
+ )
236
+
237
+ self.num_attention_heads = config.num_attention_heads
238
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
239
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
240
+
241
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
242
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
243
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
244
+ self.scale = self.attention_head_size**-0.5
245
+
246
+ self.dropout_prob = config.attention_probs_dropout_prob
247
+ self.position_embedding_type = position_embedding_type or getattr(
248
+ config, "position_embedding_type", "absolute"
249
+ )
250
+ self.rotary_embeddings = None
251
+ if self.position_embedding_type == "rotary":
252
+ self.rotary_embeddings = RotaryEmbedding(dim=self.attention_head_size)
253
+
254
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
255
+ return rearrange(x, 'b s (h d) -> b h s d', h=self.num_attention_heads)
256
+
257
+ def forward(
258
+ self,
259
+ hidden_states: torch.Tensor,
260
+ attention_mask: Optional[torch.FloatTensor] = None,
261
+ output_attentions: bool = False,
262
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
263
+ """Forward pass for self attention.
264
+
265
+ Args:
266
+ hidden_states: Input tensor
267
+ attention_mask: Optional attention mask
268
+ output_attentions: Whether to return attention weights
269
+
270
+ Returns:
271
+ Output tensor and optionally attention weights
272
+ """
273
+ query_layer = self.transpose_for_scores(self.query(hidden_states)) * self.scale
274
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
275
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
276
+
277
+ if self.position_embedding_type == "rotary":
278
+ query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
279
+
280
+ if output_attentions:
281
+ # Manual attention computation to get attention weights
282
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
283
+ if attention_mask is not None:
284
+ attention_scores = attention_scores + attention_mask
285
+ attention_probs = F.softmax(attention_scores, dim=-1)
286
+ if self.dropout_prob > 0:
287
+ attention_probs = F.dropout(attention_probs, p=self.dropout_prob, training=self.training)
288
+ context_layer = torch.matmul(attention_probs, value_layer)
289
+ context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
290
+ return context_layer, attention_probs
291
+ else:
292
+ context_layer = F.scaled_dot_product_attention(
293
+ query_layer,
294
+ key_layer,
295
+ value_layer,
296
+ attn_mask=attention_mask,
297
+ dropout_p=self.dropout_prob,
298
+ scale=1.0
299
+ )
300
+ context_layer = rearrange(context_layer, 'b h s d -> b s (h d)')
301
+ return context_layer
302
+
303
+
304
+ class EsmAttention(nn.Module):
305
+ def __init__(self, config):
306
+ super().__init__()
307
+ self.self = EsmSelfAttention(config)
308
+ self.output = EsmSelfOutput(config)
309
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
310
+
311
+ def forward(
312
+ self,
313
+ hidden_states: torch.Tensor,
314
+ attention_mask: Optional[torch.FloatTensor] = None,
315
+ output_attentions: bool = False,
316
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
317
+ """Forward pass for attention layer.
318
+
319
+ Args:
320
+ hidden_states: Input tensor
321
+ attention_mask: Optional attention mask
322
+ output_attentions: Whether to return attention weights
323
+
324
+ Returns:
325
+ Output tensor and optionally attention weights
326
+ """
327
+ hidden_states_ln = self.LayerNorm(hidden_states)
328
+ self_outputs = self.self(
329
+ hidden_states_ln,
330
+ attention_mask,
331
+ output_attentions,
332
+ )
333
+ if output_attentions:
334
+ attention_output, attention_weights = self_outputs
335
+ attention_output = self.output(attention_output, hidden_states)
336
+ return attention_output, attention_weights
337
+ else:
338
+ attention_output = self_outputs
339
+ return self.output(attention_output, hidden_states)
340
+
341
+
342
+ class EsmLayer(nn.Module):
343
+ def __init__(self, config):
344
+ super().__init__()
345
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
346
+ self.seq_len_dim = 1
347
+ self.attention = EsmAttention(config)
348
+ self.intermediate = EsmIntermediate(config)
349
+ self.output = EsmOutput(config)
350
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
351
+
352
+ def forward(
353
+ self,
354
+ hidden_states: torch.Tensor,
355
+ attention_mask: Optional[torch.FloatTensor] = None,
356
+ output_attentions: bool = False,
357
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
358
+ """Forward pass for transformer layer.
359
+
360
+ Args:
361
+ hidden_states: Input tensor
362
+ attention_mask: Optional attention mask
363
+ output_attentions: Whether to return attention weights
364
+
365
+ Returns:
366
+ Output tensor and optionally attention weights
367
+ """
368
+ attention_outputs = self.attention(
369
+ hidden_states,
370
+ attention_mask,
371
+ output_attentions,
372
+ )
373
+ if output_attentions:
374
+ attention_output, attention_weights = attention_outputs
375
+ else:
376
+ attention_output = attention_outputs
377
+ attention_weights = None
378
+
379
+ layer_output = self.feed_forward_chunk(attention_output)
380
+
381
+ if output_attentions:
382
+ return layer_output, attention_weights
383
+ return layer_output
384
+
385
+ def feed_forward_chunk(self, attention_output):
386
+ attention_output_ln = self.LayerNorm(attention_output)
387
+ intermediate_output = self.intermediate(attention_output_ln)
388
+ layer_output = self.output(intermediate_output, attention_output)
389
+ return layer_output
390
+
391
+
392
+ class EsmEncoder(nn.Module):
393
+ def __init__(self, config):
394
+ super().__init__()
395
+ self.config = config
396
+ self.layer = nn.ModuleList([EsmLayer(config) for _ in range(config.num_hidden_layers)])
397
+ self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
398
+ self.gradient_checkpointing = False
399
+
400
+ def forward(
401
+ self,
402
+ hidden_states: torch.Tensor,
403
+ attention_mask: Optional[torch.FloatTensor] = None,
404
+ output_hidden_states: bool = False,
405
+ output_attentions: bool = False,
406
+ ) -> BaseModelOutputWithPastAndCrossAttentions:
407
+ """Forward pass for transformer encoder.
408
+
409
+ Args:
410
+ hidden_states: Input tensor
411
+ attention_mask: Optional attention mask
412
+ output_hidden_states: Whether to return all hidden states
413
+ output_attentions: Whether to return attention weights
414
+
415
+ Returns:
416
+ BaseModelOutputWithPastAndCrossAttentions containing model outputs
417
+ """
418
+ all_hidden_states = () if output_hidden_states else None
419
+ all_attentions = () if output_attentions else None
420
+
421
+ for layer_module in self.layer:
422
+ if output_hidden_states:
423
+ all_hidden_states = all_hidden_states + (hidden_states,)
424
+
425
+ if self.gradient_checkpointing and self.training:
426
+ layer_outputs = self._gradient_checkpointing_func(
427
+ layer_module.__call__,
428
+ hidden_states,
429
+ attention_mask,
430
+ output_attentions,
431
+ )
432
+ else:
433
+ layer_outputs = layer_module(
434
+ hidden_states,
435
+ attention_mask,
436
+ output_attentions,
437
+ )
438
+
439
+ if output_attentions:
440
+ hidden_states, attention_weights = layer_outputs
441
+ all_attentions = all_attentions + (attention_weights,)
442
+ else:
443
+ hidden_states = layer_outputs
444
+
445
+ if self.emb_layer_norm_after:
446
+ hidden_states = self.emb_layer_norm_after(hidden_states)
447
+
448
+ if output_hidden_states:
449
+ all_hidden_states = all_hidden_states + (hidden_states,)
450
+
451
+ return BaseModelOutputWithPastAndCrossAttentions(
452
+ last_hidden_state=hidden_states,
453
+ hidden_states=all_hidden_states,
454
+ attentions=all_attentions,
455
+ )
456
+
457
+
458
+ ### Dataset for Embedding
459
+ class ProteinDataset(Dataset):
460
+ """Simple dataset for protein sequences."""
461
+ def __init__(self, sequences: list[str]):
462
+ self.sequences = sequences
463
+
464
+ def __len__(self) -> int:
465
+ return len(self.sequences)
466
+
467
+ def __getitem__(self, idx: int) -> str:
468
+ return self.sequences[idx]
469
+
470
+
471
+ class FastEsmPreTrainedModel(PreTrainedModel):
472
+ """
473
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
474
+ models.
475
+ """
476
+ config_class = FastEsmConfig
477
+ base_model_prefix = "fastesm"
478
+ supports_gradient_checkpointing = True
479
+ tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
480
+ def _init_weights(self, module):
481
+ """Initialize the weights"""
482
+ if isinstance(module, nn.Linear):
483
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
484
+ if module.bias is not None:
485
+ module.bias.data.zero_()
486
+ elif isinstance(module, nn.Embedding):
487
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
488
+ if module.padding_idx is not None:
489
+ module.weight.data[module.padding_idx].zero_()
490
+ elif isinstance(module, nn.LayerNorm):
491
+ module.bias.data.zero_()
492
+ module.weight.data.fill_(1.0)
493
+
494
+ def get_input_embeddings(self) -> nn.Module:
495
+ try:
496
+ return self.embeddings.word_embeddings
497
+ except AttributeError:
498
+ return self.esm.embeddings.word_embeddings
499
+
500
+ @property
501
+ def device(self) -> torch.device:
502
+ """Get the device of the model."""
503
+ return next(self.parameters()).device
504
+
505
+ def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
506
+ """Apply mean pooling to sequence outputs."""
507
+ if attention_mask is None:
508
+ return x.mean(dim=1)
509
+ else:
510
+ attention_mask = attention_mask.unsqueeze(-1)
511
+ return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
512
+
513
+ def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
514
+ """Collate function for batching sequences."""
515
+ return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
516
+
517
+ def _read_sequences_from_db(self, db_path: str) -> set[str]:
518
+ """Read sequences from SQLite database."""
519
+ import sqlite3
520
+ sequences = []
521
+ with sqlite3.connect(db_path) as conn:
522
+ c = conn.cursor()
523
+ c.execute("SELECT sequence FROM embeddings")
524
+ while True:
525
+ row = c.fetchone()
526
+ if row is None:
527
+ break
528
+ sequences.append(row[0])
529
+ return set(sequences)
530
+
531
+ def embed_dataset(
532
+ self,
533
+ sequences: list[str],
534
+ batch_size: int = 2,
535
+ max_len: int = 512,
536
+ full_embeddings: bool = False,
537
+ full_precision: bool = False,
538
+ pooling_type: str = 'mean',
539
+ num_workers: int = 0,
540
+ sql: bool = False,
541
+ sql_db_path: str = 'embeddings.db',
542
+ ) -> Optional[dict[str, torch.Tensor]]:
543
+ """Embed a dataset of protein sequences.
544
+
545
+ Args:
546
+ sequences: List of protein sequences
547
+ batch_size: Batch size for processing
548
+ max_len: Maximum sequence length
549
+ full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
550
+ full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
551
+ pooling_type: Type of pooling ('mean' or 'cls')
552
+ num_workers: Number of workers for data loading, 0 for the main process
553
+ sql: Whether to store embeddings in SQLite database - will be stored in float32
554
+ sql_db_path: Path to SQLite database
555
+
556
+ Returns:
557
+ Dictionary mapping sequences to embeddings, or None if sql=True
558
+ """
559
+ sequences = list(set([seq[:max_len] for seq in sequences]))
560
+ sequences = sorted(sequences, key=len, reverse=True)
561
+ dataset = ProteinDataset(sequences)
562
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn)
563
+ device = self.device
564
+
565
+ def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
566
+ if full_embeddings:
567
+ return residue_embeddings
568
+ elif pooling_type == 'mean':
569
+ return self.mean_pooling(residue_embeddings, attention_mask)
570
+ else:
571
+ return residue_embeddings[:, 0, :]
572
+
573
+ if sql:
574
+ import sqlite3
575
+ conn = sqlite3.connect(sql_db_path)
576
+ c = conn.cursor()
577
+ c.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence text PRIMARY KEY, embedding blob)')
578
+ already_embedded = self._read_sequences_from_db(sql_db_path)
579
+ to_embed = [seq for seq in sequences if seq not in already_embedded]
580
+ print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
581
+ print(f"Embedding {len(to_embed)} new sequences")
582
+ if len(to_embed) > 0:
583
+ with torch.no_grad():
584
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
585
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
586
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
587
+ residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
588
+ embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
589
+
590
+ for seq, emb in zip(seqs, embeddings):
591
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
592
+ (seq, emb.cpu().numpy().tobytes()))
593
+
594
+ if (i + 1) % 100 == 0:
595
+ conn.commit()
596
+
597
+ conn.commit()
598
+ conn.close()
599
+ return None
600
+
601
+ embeddings_dict = {}
602
+ with torch.no_grad():
603
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
604
+ seqs = sequences[i * batch_size:(i + 1) * batch_size]
605
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
606
+ residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()
607
+ if full_precision:
608
+ residue_embeddings = residue_embeddings.float()
609
+ embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
610
+ for seq, emb in zip(seqs, embeddings):
611
+ embeddings_dict[seq] = emb
612
+
613
+ return embeddings_dict
614
+
615
+
616
+ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
617
+ def __init__(self, config, add_pooling_layer=True):
618
+ super().__init__(config)
619
+ self.config = config
620
+ self.embeddings = EsmEmbeddings(config)
621
+ self.encoder = EsmEncoder(config)
622
+ # Initialize weights and apply final processing
623
+ self.post_init()
624
+
625
+ def get_input_embeddings(self):
626
+ return self.embeddings.word_embeddings
627
+
628
+ def set_input_embeddings(self, value):
629
+ self.embeddings.word_embeddings = value
630
+
631
+ def forward(
632
+ self,
633
+ input_ids: Optional[torch.LongTensor] = None,
634
+ attention_mask: Optional[torch.Tensor] = None,
635
+ position_ids: Optional[torch.LongTensor] = None,
636
+ inputs_embeds: Optional[torch.FloatTensor] = None,
637
+ output_attentions: Optional[bool] = None,
638
+ output_hidden_states: Optional[bool] = None,
639
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
640
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
641
+ """Forward pass for base model.
642
+
643
+ Args:
644
+ input_ids: Input token IDs
645
+ attention_mask: Optional attention mask
646
+ position_ids: Optional position IDs
647
+ inputs_embeds: Optional input embeddings
648
+ output_hidden_states: Whether to return all hidden states
649
+ output_attentions: Whether to return attention weights
650
+
651
+ Returns:
652
+ Model outputs including hidden states and optionally attention weights
653
+ """
654
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
655
+ output_hidden_states = (
656
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
657
+ )
658
+
659
+ if input_ids is not None and inputs_embeds is not None:
660
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
661
+ elif input_ids is not None:
662
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
663
+ input_shape = input_ids.size()
664
+ elif inputs_embeds is not None:
665
+ input_shape = inputs_embeds.size()[:-1]
666
+ else:
667
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
668
+
669
+ batch_size, seq_length = input_shape
670
+ embedding_output = self.embeddings(
671
+ input_ids=input_ids,
672
+ position_ids=position_ids,
673
+ attention_mask=attention_mask,
674
+ inputs_embeds=inputs_embeds,
675
+ )
676
+
677
+ if attention_mask is not None:
678
+ extended_attention_mask = attention_mask[:, None, None, :].expand(
679
+ batch_size, 1, seq_length, seq_length
680
+ ).bool()
681
+ else:
682
+ extended_attention_mask = None
683
+
684
+ encoder_outputs = self.encoder(
685
+ embedding_output,
686
+ attention_mask=extended_attention_mask,
687
+ output_hidden_states=output_hidden_states,
688
+ output_attentions=output_attentions,
689
+ )
690
+ sequence_output = encoder_outputs.last_hidden_state
691
+
692
+ return BaseModelOutputWithPoolingAndCrossAttentions(
693
+ last_hidden_state=sequence_output,
694
+ hidden_states=encoder_outputs.hidden_states,
695
+ attentions=encoder_outputs.attentions,
696
+ )
697
+
698
+
699
+ class FastEsmModel(FastEsmPreTrainedModel):
700
+ def __init__(self, config, add_pooling_layer=True):
701
+ super().__init__(config)
702
+ self.config = config
703
+ self.esm = FAST_ESM_ENCODER(config)
704
+ self.pooler = EsmPooler(config) if add_pooling_layer else None
705
+ # Initialize weights and apply final processing
706
+ self.post_init()
707
+
708
+ def get_input_embeddings(self):
709
+ return self.embeddings.word_embeddings
710
+
711
+ def set_input_embeddings(self, value):
712
+ self.embeddings.word_embeddings = value
713
+
714
+ def forward(
715
+ self,
716
+ input_ids: Optional[torch.LongTensor] = None,
717
+ attention_mask: Optional[torch.Tensor] = None,
718
+ position_ids: Optional[torch.LongTensor] = None,
719
+ inputs_embeds: Optional[torch.FloatTensor] = None,
720
+ output_attentions: Optional[bool] = None,
721
+ output_hidden_states: Optional[bool] = None,
722
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
723
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
724
+ """Forward pass for base model.
725
+
726
+ Args:
727
+ input_ids: Input token IDs
728
+ attention_mask: Optional attention mask
729
+ position_ids: Optional position IDs
730
+ inputs_embeds: Optional input embeddings
731
+ output_hidden_states: Whether to return all hidden states
732
+ output_attentions: Whether to return attention weights
733
+
734
+ Returns:
735
+ Model outputs including hidden states and optionally attention weights
736
+ """
737
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
738
+ output_hidden_states = (
739
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
740
+ )
741
+
742
+ if input_ids is not None and inputs_embeds is not None:
743
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
744
+ elif input_ids is not None:
745
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
746
+ input_shape = input_ids.size()
747
+ elif inputs_embeds is not None:
748
+ input_shape = inputs_embeds.size()[:-1]
749
+ else:
750
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
751
+
752
+ batch_size, seq_length = input_shape
753
+ embedding_output = self.embeddings(
754
+ input_ids=input_ids,
755
+ position_ids=position_ids,
756
+ attention_mask=attention_mask,
757
+ inputs_embeds=inputs_embeds,
758
+ )
759
+
760
+ if attention_mask is not None:
761
+ extended_attention_mask = attention_mask[:, None, None, :].expand(
762
+ batch_size, 1, seq_length, seq_length
763
+ ).bool()
764
+ else:
765
+ extended_attention_mask = None
766
+
767
+ encoder_outputs = self.encoder(
768
+ embedding_output,
769
+ attention_mask=extended_attention_mask,
770
+ output_hidden_states=output_hidden_states,
771
+ output_attentions=output_attentions,
772
+ )
773
+ sequence_output = encoder_outputs.last_hidden_state
774
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
775
+
776
+ return BaseModelOutputWithPoolingAndCrossAttentions(
777
+ last_hidden_state=sequence_output,
778
+ pooler_output=pooled_output,
779
+ hidden_states=encoder_outputs.hidden_states,
780
+ attentions=encoder_outputs.attentions,
781
+ )
782
+
783
+
784
+ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
785
+ _tied_weights_keys = ["lm_head.decoder.weight"]
786
+
787
+ def __init__(self, config):
788
+ super().__init__(config)
789
+ self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
790
+ self.lm_head = EsmLMHead(config)
791
+ self.loss_fct = nn.CrossEntropyLoss()
792
+ self.init_weights()
793
+
794
+ def get_output_embeddings(self):
795
+ return self.lm_head.decoder
796
+
797
+ def set_output_embeddings(self, new_embeddings):
798
+ self.lm_head.decoder = new_embeddings
799
+
800
+ def forward(
801
+ self,
802
+ input_ids: Optional[torch.LongTensor] = None,
803
+ attention_mask: Optional[torch.Tensor] = None,
804
+ position_ids: Optional[torch.LongTensor] = None,
805
+ inputs_embeds: Optional[torch.FloatTensor] = None,
806
+ labels: Optional[torch.LongTensor] = None,
807
+ output_attentions: Optional[bool] = None,
808
+ output_hidden_states: Optional[bool] = None,
809
+ return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
810
+ ) -> Union[Tuple, MaskedLMOutput]:
811
+ outputs = self.esm(
812
+ input_ids,
813
+ attention_mask=attention_mask,
814
+ position_ids=position_ids,
815
+ inputs_embeds=inputs_embeds,
816
+ output_hidden_states=output_hidden_states,
817
+ output_attentions=output_attentions,
818
+ )
819
+ sequence_output = outputs.last_hidden_state
820
+ prediction_scores = self.lm_head(sequence_output)
821
+
822
+ loss = None
823
+ if labels is not None:
824
+ labels = labels.to(prediction_scores.device)
825
+ loss = self.loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
826
+
827
+ return MaskedLMOutput(
828
+ loss=loss,
829
+ logits=prediction_scores,
830
+ hidden_states=outputs.hidden_states,
831
+ attentions=outputs.attentions,
832
+ )
833
+
834
+ def predict_contacts(self, tokens, attention_mask):
835
+ raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
836
+
837
+
838
+ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
839
+ def __init__(self, config):
840
+ super().__init__(config)
841
+ self.num_labels = config.num_labels
842
+ self.config = config
843
+ self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
844
+ self.classifier = EsmClassificationHead(config)
845
+ self.mse = nn.MSELoss()
846
+ self.ce = nn.CrossEntropyLoss()
847
+ self.bce = nn.BCEWithLogitsLoss()
848
+ self.init_weights()
849
+
850
+ def forward(
851
+ self,
852
+ input_ids: Optional[torch.LongTensor] = None,
853
+ attention_mask: Optional[torch.Tensor] = None,
854
+ position_ids: Optional[torch.LongTensor] = None,
855
+ inputs_embeds: Optional[torch.FloatTensor] = None,
856
+ labels: Optional[torch.LongTensor] = None,
857
+ output_attentions: Optional[bool] = None,
858
+ output_hidden_states: Optional[bool] = None,
859
+ ) -> Union[Tuple, SequenceClassifierOutput]:
860
+ outputs = self.esm(
861
+ input_ids,
862
+ attention_mask=attention_mask,
863
+ position_ids=position_ids,
864
+ inputs_embeds=inputs_embeds,
865
+ output_attentions=output_attentions,
866
+ output_hidden_states=output_hidden_states,
867
+ )
868
+ sequence_output = outputs.last_hidden_state
869
+ logits = self.classifier(sequence_output)
870
+
871
+ loss = None
872
+ if labels is not None:
873
+ labels = labels.to(logits.device)
874
+ if self.config.problem_type is None:
875
+ if self.num_labels == 1:
876
+ self.config.problem_type = "regression"
877
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
878
+ self.config.problem_type = "single_label_classification"
879
+ else:
880
+ self.config.problem_type = "multi_label_classification"
881
+
882
+ if self.config.problem_type == "regression":
883
+ if self.num_labels == 1:
884
+ loss = self.mse(logits.squeeze(), labels.squeeze())
885
+ else:
886
+ loss = self.mse(logits, labels)
887
+ elif self.config.problem_type == "single_label_classification":
888
+ loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1))
889
+ elif self.config.problem_type == "multi_label_classification":
890
+ loss = self.bce(logits, labels)
891
+
892
+ return SequenceClassifierOutput(
893
+ loss=loss,
894
+ logits=logits,
895
+ hidden_states=outputs.hidden_states,
896
+ attentions=outputs.attentions,
897
+ )
898
+
899
+
900
+ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
901
+ def __init__(self, config):
902
+ super().__init__(config)
903
+ self.num_labels = config.num_labels
904
+ self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
905
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
906
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
907
+ self.loss_fct = nn.CrossEntropyLoss()
908
+ self.init_weights()
909
+
910
+ def forward(
911
+ self,
912
+ input_ids: Optional[torch.LongTensor] = None,
913
+ attention_mask: Optional[torch.Tensor] = None,
914
+ position_ids: Optional[torch.LongTensor] = None,
915
+ inputs_embeds: Optional[torch.FloatTensor] = None,
916
+ labels: Optional[torch.LongTensor] = None,
917
+ output_attentions: Optional[bool] = None,
918
+ output_hidden_states: Optional[bool] = None,
919
+ ) -> Union[Tuple, TokenClassifierOutput]:
920
+ outputs = self.esm(
921
+ input_ids,
922
+ attention_mask=attention_mask,
923
+ position_ids=position_ids,
924
+ inputs_embeds=inputs_embeds,
925
+ output_attentions=output_attentions,
926
+ output_hidden_states=output_hidden_states,
927
+ )
928
+ sequence_output = outputs.last_hidden_state
929
+ sequence_output = self.dropout(sequence_output)
930
+ logits = self.classifier(sequence_output)
931
+
932
+ loss = None
933
+ if labels is not None:
934
+ labels = labels.to(logits.device)
935
+ loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
936
+
937
+ return TokenClassifierOutput(
938
+ loss=loss,
939
+ logits=logits,
940
+ hidden_states=outputs.hidden_states,
941
+ attentions=outputs.attentions,
942
+ )
943
+
944
+
945
+ if __name__ == "__main__":
946
+ """
947
+ Test the hidden state differences between the FastEsmModel and the HF EsmModel.
948
+ In full precision, the differences are very very small, but nonzero due to floating point issues with F.scaled_dot_product_attention.
949
+ In Pytorch 2.5+ (and linux kernel), this implementation is very fast and uses less memory than the HF implementation.
950
+ """
951
+ import random
952
+ from transformers import EsmForMaskedLM as TransformersEsmModel, EsmTokenizer
953
+
954
+ model_paths = [
955
+ "facebook/esm2_t6_8M_UR50D",
956
+ "facebook/esm2_t12_35M_UR50D",
957
+ #"facebook/esm2_t30_150M_UR50D",
958
+ #"facebook/esm2_t33_650M_UR50D",
959
+ ]
960
+ canonical_amino_acids = "ACDEFGHIKLMNPQRSTVWY"
961
+ length = 64
962
+ seq_count = 100
963
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
964
+ tolerances = [1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
965
+
966
+ def generate_random_sequence(length: int) -> str:
967
+ return 'M' + "".join(random.choices(canonical_amino_acids, k=length))
968
+
969
+ print("Percentage of hidden states that are within the tolerance:")
970
+ for model_path in model_paths:
971
+ print(f"Testing {model_path}...")
972
+ tokenizer = EsmTokenizer.from_pretrained(model_path)
973
+ config = FastEsmConfig.from_pretrained(model_path)
974
+ fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
975
+ model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
976
+
977
+ counts = [0] * len(tolerances)
978
+ for _ in range(seq_count):
979
+ example_seq = generate_random_sequence(length)
980
+ fast_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
981
+ fast_output = fast_model(fast_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
982
+
983
+ model_tokens = tokenizer(example_seq, return_tensors="pt").input_ids.to(device)
984
+ model_output = model(model_tokens, output_hidden_states=True).hidden_states[-1].detach().cpu()
985
+
986
+ for i, atol in enumerate(tolerances):
987
+ if torch.allclose(fast_output, model_output, atol=atol):
988
+ counts[i] += 1
989
+
990
+ print(f"{model_path}:")
991
+ for i, atol in enumerate(tolerances):
992
+ print(f" tolerance={atol}: {counts[i] / seq_count * 100}%")
993
+
994
+ model.cpu()
995
+ fast_model.cpu()
996
+ del model
997
+ del fast_model
998
+ torch.cuda.empty_cache()