lhallee committed
Commit 9966ab1 · verified · 1 Parent(s): 1bdd914

Upload modeling_fastesm.py with huggingface_hub

Files changed (1):
  1. modeling_fastesm.py +332 -152
modeling_fastesm.py CHANGED
@@ -1,11 +1,13 @@
1
  import torch
2
  import torch.nn as nn
 
3
  from torch.nn import functional as F
4
- from torch.utils.data import Dataset, DataLoader
5
- from typing import Optional, Tuple, Union
 
6
  from einops import rearrange
7
  from dataclasses import dataclass
8
- from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer
9
  from transformers.modeling_outputs import (
10
  ModelOutput,
11
  BaseModelOutputWithPastAndCrossAttentions,
@@ -26,31 +28,31 @@ from tqdm.auto import tqdm
26
 
27
  @dataclass
28
  class EsmMaskedLMOutput(ModelOutput):
29
- loss: Optional[torch.FloatTensor] = None
30
- logits: Optional[torch.FloatTensor] = None
31
- last_hidden_state: Optional[torch.FloatTensor] = None
32
- hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
33
- attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
34
 
35
 
36
  class FastEsmConfig(PretrainedConfig):
37
  model_type = "fast_esm"
38
  def __init__(
39
  self,
40
- vocab_size=None,
41
- mask_token_id=None,
42
- pad_token_id=None,
43
- hidden_size=768,
44
- num_hidden_layers=12,
45
- num_attention_heads=12,
46
- intermediate_size=3072,
47
- hidden_dropout_prob=0.1,
48
- attention_probs_dropout_prob=0.1,
49
- max_position_embeddings=1026,
50
- initializer_range=0.02,
51
- layer_norm_eps=1e-12,
52
- position_embedding_type="absolute",
53
- emb_layer_norm_before=None,
54
  **kwargs,
55
  ):
56
  super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
@@ -68,35 +70,35 @@ class FastEsmConfig(PretrainedConfig):
68
  self.position_embedding_type = position_embedding_type
69
  self.emb_layer_norm_before = emb_layer_norm_before
70
 
71
- def to_dict(self):
72
  """
73
  Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
74
 
75
  Returns:
76
- `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
77
  """
78
  output = super().to_dict()
79
  return output
80
 
81
 
82
- def rotate_half(x):
83
  x1, x2 = x.chunk(2, dim=-1)
84
  return torch.cat((-x2, x1), dim=-1)
85
 
86
 
87
- def apply_rotary_pos_emb(x, cos, sin):
88
  cos = cos[:, :, : x.shape[-2], :]
89
  sin = sin[:, :, : x.shape[-2], :]
90
 
91
  return (x * cos) + (rotate_half(x) * sin)
92
 
93
 
94
- def symmetrize(x):
95
  "Make layer symmetric in final two dimensions, used for contact prediction."
96
  return x + x.transpose(-1, -2)
97
 
98
 
99
- def average_product_correct(x):
100
  "Perform average product correct, used for contact prediction."
101
  a1 = x.sum(-1, keepdims=True)
102
  a2 = x.sum(-2, keepdims=True)
@@ -114,18 +116,18 @@ class EsmContactPredictionHead(nn.Module):
114
  def __init__(
115
  self,
116
  in_features: int,
117
- bias=True,
118
  eos_idx: int = 2,
119
  ):
120
  super().__init__()
121
  self.in_features = in_features
122
  self.eos_idx = eos_idx
123
- self.regression = nn.Linear(in_features, 1, bias)
124
  self.activation = nn.Sigmoid()
125
 
126
- def forward(self, tokens, attentions):
127
  # remove eos token attentions
128
- eos_mask = tokens.ne(self.eos_idx).to(attentions)
129
  eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
130
  attentions = attentions * eos_mask[:, None, None, :, :]
131
  attentions = attentions[..., :-1, :-1]
@@ -161,7 +163,7 @@ class RotaryEmbedding(torch.nn.Module):
161
  self._cos_cached = None
162
  self._sin_cached = None
163
 
164
- def _update_cos_sin_tables(self, x, seq_dimension=2):
165
  seq_len = x.shape[seq_dimension]
166
 
167
  # Reset the tables if the sequence length has changed,
@@ -204,7 +206,12 @@ class EsmEmbeddings(nn.Module):
204
  )
205
 
206
  def forward(
207
- self, input_ids=None, attention_mask=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
208
  ):
209
  if inputs_embeds is None:
210
  inputs_embeds = self.word_embeddings(input_ids)
@@ -236,7 +243,7 @@ class EsmEmbeddings(nn.Module):
236
 
237
 
238
  class EsmSelfAttention(nn.Module):
239
- def __init__(self, config, position_embedding_type=None):
240
  super().__init__()
241
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
242
  raise ValueError(
@@ -267,8 +274,8 @@ class EsmSelfAttention(nn.Module):
267
  def forward(
268
  self,
269
  hidden_states: torch.Tensor,
270
- attention_mask: Optional[torch.FloatTensor] = None,
271
- output_attentions: bool = False,
272
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
273
  """Forward pass for self attention.
274
 
@@ -321,8 +328,8 @@ class EsmAttention(nn.Module):
321
  def forward(
322
  self,
323
  hidden_states: torch.Tensor,
324
- attention_mask: Optional[torch.FloatTensor] = None,
325
- output_attentions: bool = False,
326
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
327
  """Forward pass for attention layer.
328
 
@@ -362,8 +369,8 @@ class EsmLayer(nn.Module):
362
  def forward(
363
  self,
364
  hidden_states: torch.Tensor,
365
- attention_mask: Optional[torch.FloatTensor] = None,
366
- output_attentions: bool = False,
367
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
368
  """Forward pass for transformer layer.
369
 
@@ -410,9 +417,9 @@ class EsmEncoder(nn.Module):
410
  def forward(
411
  self,
412
  hidden_states: torch.Tensor,
413
- attention_mask: Optional[torch.FloatTensor] = None,
414
- output_hidden_states: bool = False,
415
- output_attentions: bool = False,
416
  ) -> BaseModelOutputWithPastAndCrossAttentions:
417
  """Forward pass for transformer encoder.
418
 
@@ -465,8 +472,90 @@ class EsmEncoder(nn.Module):
465
  )
466
 
467
 
468
- ### Dataset for Embedding
469
- class ProteinDataset(Dataset):
470
  """Simple dataset for protein sequences."""
471
  def __init__(self, sequences: list[str]):
472
  self.sequences = sequences
@@ -478,52 +567,22 @@ class ProteinDataset(Dataset):
478
  return self.sequences[idx]
479
 
480
 
481
- class FastEsmPreTrainedModel(PreTrainedModel):
482
- """
483
- An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
484
- models.
485
- """
486
- config_class = FastEsmConfig
487
- base_model_prefix = "fastesm"
488
- supports_gradient_checkpointing = True
489
- tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
490
- def _init_weights(self, module):
491
- """Initialize the weights"""
492
- if isinstance(module, nn.Linear):
493
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
494
- if module.bias is not None:
495
- module.bias.data.zero_()
496
- elif isinstance(module, nn.Embedding):
497
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
498
- if module.padding_idx is not None:
499
- module.weight.data[module.padding_idx].zero_()
500
- elif isinstance(module, nn.LayerNorm):
501
- module.bias.data.zero_()
502
- module.weight.data.fill_(1.0)
503
 
504
- def get_input_embeddings(self) -> nn.Module:
505
- try:
506
- return self.embeddings.word_embeddings
507
- except AttributeError:
508
- return self.esm.embeddings.word_embeddings
509
 
510
  @property
511
  def device(self) -> torch.device:
512
  """Get the device of the model."""
513
  return next(self.parameters()).device
514
 
515
- def mean_pooling(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
516
- """Apply mean pooling to sequence outputs."""
517
- if attention_mask is None:
518
- return x.mean(dim=1)
519
- else:
520
- attention_mask = attention_mask.unsqueeze(-1)
521
- return (x * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
522
-
523
- def _collate_fn(self, sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
524
- """Collate function for batching sequences."""
525
- return self.tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
526
-
527
  def _read_sequences_from_db(self, db_path: str) -> set[str]:
528
  """Read sequences from SQLite database."""
529
  import sqlite3
@@ -540,15 +599,18 @@ class FastEsmPreTrainedModel(PreTrainedModel):
540
 
541
  def embed_dataset(
542
  self,
543
- sequences: list[str],
 
544
  batch_size: int = 2,
545
  max_len: int = 512,
546
  full_embeddings: bool = False,
547
- full_precision: bool = False,
548
- pooling_type: str = 'mean',
549
  num_workers: int = 0,
550
  sql: bool = False,
 
551
  sql_db_path: str = 'embeddings.db',
 
552
  ) -> Optional[dict[str, torch.Tensor]]:
553
  """Embed a dataset of protein sequences.
554
 
@@ -557,7 +619,6 @@ class FastEsmPreTrainedModel(PreTrainedModel):
557
  batch_size: Batch size for processing
558
  max_len: Maximum sequence length
559
  full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
560
- full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
561
  pooling_type: Type of pooling ('mean' or 'cls')
562
  num_workers: Number of workers for data loading, 0 for the main process
563
  sql: Whether to store embeddings in SQLite database - will be stored in float32
@@ -565,18 +626,46 @@ class FastEsmPreTrainedModel(PreTrainedModel):
565
 
566
  Returns:
567
  Dictionary mapping sequences to embeddings, or None if sql=True
568
  """
569
  device = self.device
 
570
 
571
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
572
- if full_embeddings:
573
  return residue_embeddings
574
- elif pooling_type == 'mean':
575
- return self.mean_pooling(residue_embeddings, attention_mask)
576
  else:
577
- return residue_embeddings[:, 0, :]
578
 
579
- sequences = list(set([seq[:max_len] for seq in sequences]))
580
  if sql:
581
  import sqlite3
582
  conn = sqlite3.connect(sql_db_path)
@@ -587,52 +676,94 @@ class FastEsmPreTrainedModel(PreTrainedModel):
587
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
588
  print(f"Embedding {len(to_embed)} new sequences")
589
  if len(to_embed) > 0:
590
- to_embed = sorted(to_embed, key=len, reverse=True)
591
  dataset = ProteinDataset(to_embed)
592
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
593
  with torch.no_grad():
594
  for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
595
- seqs = sequences[i * batch_size:(i + 1) * batch_size]
596
  input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
597
- residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float() # required for sql
598
  embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
599
-
600
- for seq, emb in zip(seqs, embeddings):
 
601
  c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
602
  (seq, emb.cpu().numpy().tobytes()))
603
 
604
  if (i + 1) % 100 == 0:
605
  conn.commit()
606
-
607
  conn.commit()
608
  conn.close()
609
  return None
610
-
611
- sequences = list(set([seq[:max_len] for seq in sequences]))
612
- sequences = sorted(sequences, key=len, reverse=True)
613
- dataset = ProteinDataset(sequences)
614
- dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate_fn, shuffle=False)
615
  embeddings_dict = {}
616
- with torch.no_grad():
617
- for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
618
- seqs = sequences[i * batch_size:(i + 1) * batch_size]
619
- input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
620
- residue_embeddings = self.forward(input_ids, attention_mask, output_hidden_states=True).hidden_states[-1].detach().float()
621
- if full_precision:
622
- residue_embeddings = residue_embeddings.float()
623
- embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
624
- for seq, emb in zip(seqs, embeddings):
625
- embeddings_dict[seq] = emb
626
-
 
627
  return embeddings_dict
628
 
629
 
630
- class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
631
- def __init__(self, config, add_pooling_layer=True):
632
- super().__init__(config)
633
  self.config = config
634
  self.embeddings = EsmEmbeddings(config)
635
  self.encoder = EsmEncoder(config)
636
  # Initialize weights and apply final processing
637
  self.post_init()
638
 
@@ -642,12 +773,36 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
642
  def set_input_embeddings(self, value):
643
  self.embeddings.word_embeddings = value
644
 
645
  def forward(
646
  self,
647
- input_ids: Optional[torch.LongTensor] = None,
648
  attention_mask: Optional[torch.Tensor] = None,
649
- position_ids: Optional[torch.LongTensor] = None,
650
- inputs_embeds: Optional[torch.FloatTensor] = None,
651
  output_attentions: Optional[bool] = None,
652
  output_hidden_states: Optional[bool] = None,
653
  return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
@@ -679,7 +834,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
679
  raise ValueError("You have to specify either input_ids or inputs_embeds")
680
 
681
  batch_size, seq_length = input_shape
682
- embedding_output = self.embeddings(
683
  input_ids=input_ids,
684
  position_ids=position_ids,
685
  attention_mask=attention_mask,
@@ -694,7 +849,7 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
694
  extended_attention_mask = None
695
 
696
  encoder_outputs = self.encoder(
697
- embedding_output,
698
  attention_mask=extended_attention_mask,
699
  output_hidden_states=output_hidden_states,
700
  output_attentions=output_attentions,
@@ -708,9 +863,9 @@ class FAST_ESM_ENCODER(FastEsmPreTrainedModel):
708
  )
709
 
710
 
711
- class FastEsmModel(FastEsmPreTrainedModel):
712
- def __init__(self, config, add_pooling_layer=True):
713
- super().__init__(config)
714
  self.config = config
715
  self.esm = FAST_ESM_ENCODER(config)
716
  self.pooler = EsmPooler(config) if add_pooling_layer else None
@@ -723,12 +878,18 @@ class FastEsmModel(FastEsmPreTrainedModel):
723
  def set_input_embeddings(self, value):
724
  self.embeddings.word_embeddings = value
725
 
726
  def forward(
727
  self,
728
- input_ids: Optional[torch.LongTensor] = None,
729
  attention_mask: Optional[torch.Tensor] = None,
730
- position_ids: Optional[torch.LongTensor] = None,
731
- inputs_embeds: Optional[torch.FloatTensor] = None,
732
  output_attentions: Optional[bool] = None,
733
  output_hidden_states: Optional[bool] = None,
734
  return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
@@ -778,11 +939,11 @@ class FastEsmModel(FastEsmPreTrainedModel):
778
  )
779
 
780
 
781
- class FastEsmForMaskedLM(FastEsmPreTrainedModel):
782
  _tied_weights_keys = ["lm_head.decoder.weight"]
783
 
784
  def __init__(self, config):
785
- super().__init__(config)
786
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
787
  self.lm_head = EsmLMHead(config)
788
  self.loss_fct = nn.CrossEntropyLoss()
@@ -794,13 +955,19 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
794
  def set_output_embeddings(self, new_embeddings):
795
  self.lm_head.decoder = new_embeddings
796
 
797
  def forward(
798
  self,
799
- input_ids: Optional[torch.LongTensor] = None,
800
  attention_mask: Optional[torch.Tensor] = None,
801
- position_ids: Optional[torch.LongTensor] = None,
802
- inputs_embeds: Optional[torch.FloatTensor] = None,
803
- labels: Optional[torch.LongTensor] = None,
804
  output_attentions: Optional[bool] = None,
805
  output_hidden_states: Optional[bool] = None,
806
  return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
@@ -829,13 +996,10 @@ class FastEsmForMaskedLM(FastEsmPreTrainedModel):
829
  attentions=outputs.attentions,
830
  )
831
 
832
- def predict_contacts(self, tokens, attention_mask):
833
- raise NotImplementedError("predict_contacts is not supported by F.scaled_dot_product_attention")
834
-
835
 
836
- class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
837
  def __init__(self, config):
838
- super().__init__(config)
839
  self.num_labels = config.num_labels
840
  self.config = config
841
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
@@ -845,13 +1009,19 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
845
  self.bce = nn.BCEWithLogitsLoss()
846
  self.init_weights()
847
 
848
  def forward(
849
  self,
850
- input_ids: Optional[torch.LongTensor] = None,
851
  attention_mask: Optional[torch.Tensor] = None,
852
- position_ids: Optional[torch.LongTensor] = None,
853
- inputs_embeds: Optional[torch.FloatTensor] = None,
854
- labels: Optional[torch.LongTensor] = None,
855
  output_attentions: Optional[bool] = None,
856
  output_hidden_states: Optional[bool] = None,
857
  return_dict: Optional[bool] = None
@@ -896,9 +1066,9 @@ class FastEsmForSequenceClassification(FastEsmPreTrainedModel):
896
  )
897
 
898
 
899
- class FastEsmForTokenClassification(FastEsmPreTrainedModel):
900
  def __init__(self, config):
901
- super().__init__(config)
902
  self.num_labels = config.num_labels
903
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
904
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -906,13 +1076,19 @@ class FastEsmForTokenClassification(FastEsmPreTrainedModel):
906
  self.loss_fct = nn.CrossEntropyLoss()
907
  self.init_weights()
908
 
909
  def forward(
910
  self,
911
- input_ids: Optional[torch.LongTensor] = None,
912
  attention_mask: Optional[torch.Tensor] = None,
913
- position_ids: Optional[torch.LongTensor] = None,
914
- inputs_embeds: Optional[torch.FloatTensor] = None,
915
- labels: Optional[torch.LongTensor] = None,
916
  output_attentions: Optional[bool] = None,
917
  output_hidden_states: Optional[bool] = None,
918
  return_dict: Optional[bool] = None
@@ -972,7 +1148,11 @@ if __name__ == "__main__":
972
  tokenizer = EsmTokenizer.from_pretrained(model_path)
973
  config = FastEsmConfig.from_pretrained(model_path)
974
  fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
975
  model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
976
 
977
  counts = [0] * len(tolerances)
978
  for _ in range(seq_count):
 
1
  import torch
2
  import torch.nn as nn
3
+ import os
4
  from torch.nn import functional as F
5
+ from torch.utils.data import Dataset as TorchDataset
6
+ from torch.utils.data import DataLoader as DataLoader
7
+ from typing import Optional, Tuple, Union, Callable, List, Dict, Any
8
  from einops import rearrange
9
  from dataclasses import dataclass
10
+ from transformers import PreTrainedModel, PretrainedConfig, EsmTokenizer, PreTrainedTokenizerBase
11
  from transformers.modeling_outputs import (
12
  ModelOutput,
13
  BaseModelOutputWithPastAndCrossAttentions,
 
28
 
29
  @dataclass
30
  class EsmMaskedLMOutput(ModelOutput):
31
+ loss: Optional[torch.Tensor] = None
32
+ logits: Optional[torch.Tensor] = None
33
+ last_hidden_state: Optional[torch.Tensor] = None
34
+ hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
35
+ attentions: Optional[Tuple[torch.Tensor, ...]] = None
36
 
37
 
38
  class FastEsmConfig(PretrainedConfig):
39
  model_type = "fast_esm"
40
  def __init__(
41
  self,
42
+ vocab_size: int = None,
43
+ mask_token_id: int = None,
44
+ pad_token_id: int = None,
45
+ hidden_size: int = 768,
46
+ num_hidden_layers: int = 12,
47
+ num_attention_heads: int = 12,
48
+ intermediate_size: int = 3072,
49
+ hidden_dropout_prob: float = 0.1,
50
+ attention_probs_dropout_prob: float = 0.1,
51
+ max_position_embeddings: int = 1026,
52
+ initializer_range: float = 0.02,
53
+ layer_norm_eps: float = 1e-12,
54
+ position_embedding_type: str = "absolute",
55
+ emb_layer_norm_before: bool = None,
56
  **kwargs,
57
  ):
58
  super().__init__(pad_token_id=pad_token_id, mask_token_id=mask_token_id, **kwargs)
 
70
  self.position_embedding_type = position_embedding_type
71
  self.emb_layer_norm_before = emb_layer_norm_before
72
 
73
+ def to_dict(self) -> Dict[str, Any]:
74
  """
75
  Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
76
 
77
  Returns:
78
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
79
  """
80
  output = super().to_dict()
81
  return output
82
 
83
 
84
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
85
  x1, x2 = x.chunk(2, dim=-1)
86
  return torch.cat((-x2, x1), dim=-1)
87
 
88
 
89
+ def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
90
  cos = cos[:, :, : x.shape[-2], :]
91
  sin = sin[:, :, : x.shape[-2], :]
92
 
93
  return (x * cos) + (rotate_half(x) * sin)
94
 
95
 
96
+ def symmetrize(x: torch.Tensor) -> torch.Tensor:
97
  "Make layer symmetric in final two dimensions, used for contact prediction."
98
  return x + x.transpose(-1, -2)
99
 
100
 
101
+ def average_product_correct(x: torch.Tensor) -> torch.Tensor:
102
  "Perform average product correct, used for contact prediction."
103
  a1 = x.sum(-1, keepdims=True)
104
  a2 = x.sum(-2, keepdims=True)
 
116
  def __init__(
117
  self,
118
  in_features: int,
119
+ bias: bool = True,
120
  eos_idx: int = 2,
121
  ):
122
  super().__init__()
123
  self.in_features = in_features
124
  self.eos_idx = eos_idx
125
+ self.regression = nn.Linear(in_features, 1, bias=bias)
126
  self.activation = nn.Sigmoid()
127
 
128
+ def forward(self, input_ids: torch.Tensor, attentions: torch.Tensor) -> torch.Tensor:
129
  # remove eos token attentions
130
+ eos_mask = input_ids.ne(self.eos_idx).to(attentions)
131
  eos_mask = eos_mask.unsqueeze(1) * eos_mask.unsqueeze(2)
132
  attentions = attentions * eos_mask[:, None, None, :, :]
133
  attentions = attentions[..., :-1, :-1]
 
163
  self._cos_cached = None
164
  self._sin_cached = None
165
 
166
+ def _update_cos_sin_tables(self, x: torch.Tensor, seq_dimension: int = 2) -> Tuple[torch.Tensor, torch.Tensor]:
167
  seq_len = x.shape[seq_dimension]
168
 
169
  # Reset the tables if the sequence length has changed,
 
206
  )
207
 
208
  def forward(
209
+ self,
210
+ input_ids: Optional[torch.Tensor] = None,
211
+ attention_mask: Optional[torch.Tensor] = None,
212
+ position_ids: Optional[torch.Tensor] = None,
213
+ inputs_embeds: Optional[torch.Tensor] = None,
214
+ past_key_values_length: Optional[int] = 0,
215
  ):
216
  if inputs_embeds is None:
217
  inputs_embeds = self.word_embeddings(input_ids)
 
243
 
244
 
245
  class EsmSelfAttention(nn.Module):
246
+ def __init__(self, config, position_embedding_type: Optional[str] = None):
247
  super().__init__()
248
  if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
249
  raise ValueError(
 
274
  def forward(
275
  self,
276
  hidden_states: torch.Tensor,
277
+ attention_mask: Optional[torch.Tensor] = None,
278
+ output_attentions: Optional[bool] = False,
279
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
280
  """Forward pass for self attention.
281
 
 
328
  def forward(
329
  self,
330
  hidden_states: torch.Tensor,
331
+ attention_mask: Optional[torch.Tensor] = None,
332
+ output_attentions: Optional[bool] = False,
333
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
334
  """Forward pass for attention layer.
335
 
 
369
  def forward(
370
  self,
371
  hidden_states: torch.Tensor,
372
+ attention_mask: Optional[torch.Tensor] = None,
373
+ output_attentions: Optional[bool] = False,
374
  ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
375
  """Forward pass for transformer layer.
376
 
 
417
  def forward(
418
  self,
419
  hidden_states: torch.Tensor,
420
+ attention_mask: Optional[torch.Tensor] = None,
421
+ output_hidden_states: Optional[bool] = False,
422
+ output_attentions: Optional[bool] = False,
423
  ) -> BaseModelOutputWithPastAndCrossAttentions:
424
  """Forward pass for transformer encoder.
425
 
 
472
  )
473
 
474
 
475
+ ### Support for embedding datasets with low code
476
+ class Pooler:
477
+ def __init__(self, pooling_types: List[str]):
478
+ self.pooling_types = pooling_types
479
+ self.pooling_options = {
480
+ 'mean': self.mean_pooling,
481
+ 'max': self.max_pooling,
482
+ 'min': self.min_pooling,
483
+ 'norm': self.norm_pooling,
484
+ 'prod': self.prod_pooling,
485
+ 'median': self.median_pooling,
486
+ 'std': self.std_pooling,
487
+ 'var': self.var_pooling,
488
+ 'cls': self.cls_pooling,
489
+ }
490
+
491
+ def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
492
+ if attention_mask is None:
493
+ return emb.mean(dim=1)
494
+ else:
495
+ attention_mask = attention_mask.unsqueeze(-1)
496
+ return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)
497
+
498
+ def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
499
+ if attention_mask is None:
500
+ return emb.max(dim=1).values
501
+ else:
502
+ attention_mask = attention_mask.unsqueeze(-1)
503
+ return (emb * attention_mask).max(dim=1).values
504
+
505
+ def min_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
506
+ if attention_mask is None:
507
+ return emb.min(dim=1).values
508
+ else:
509
+ attention_mask = attention_mask.unsqueeze(-1)
510
+ return (emb * attention_mask).min(dim=1).values
511
+
512
+ def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
513
+ if attention_mask is None:
514
+ return emb.norm(dim=1, p=2)
515
+ else:
516
+ attention_mask = attention_mask.unsqueeze(-1)
517
+ return (emb * attention_mask).norm(dim=1, p=2)
518
+
519
+ def prod_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
520
+ length = emb.shape[1]
521
+ if attention_mask is None:
522
+ return emb.prod(dim=1) / length
523
+ else:
524
+ attention_mask = attention_mask.unsqueeze(-1)
525
+ return ((emb * attention_mask).prod(dim=1) / attention_mask.sum(dim=1)) / length
526
+
527
+ def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
528
+ if attention_mask is None:
529
+ return emb.median(dim=1).values
530
+ else:
531
+ attention_mask = attention_mask.unsqueeze(-1)
532
+ return (emb * attention_mask).median(dim=1).values
533
+
534
+ def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
535
+ if attention_mask is None:
536
+ return emb.std(dim=1)
537
+ else:
538
+ attention_mask = attention_mask.unsqueeze(-1)
539
+ return (emb * attention_mask).std(dim=1)
540
+
541
+ def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
542
+ if attention_mask is None:
543
+ return emb.var(dim=1)
544
+ else:
545
+ attention_mask = attention_mask.unsqueeze(-1)
546
+ return (emb * attention_mask).var(dim=1)
547
+
548
+ def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # (b, L, d) -> (b, d)
549
+ return emb[:, 0, :]
550
+
551
+ def __call__(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None): # [mean, max]
552
+ final_emb = []
553
+ for pooling_type in self.pooling_types:
554
+ final_emb.append(self.pooling_options[pooling_type](emb, attention_mask)) # (b, d)
555
+ return torch.cat(final_emb, dim=-1) # (b, n_pooling_types * d)
556
+
557
+
558
+ class ProteinDataset(TorchDataset):
559
  """Simple dataset for protein sequences."""
560
  def __init__(self, sequences: list[str]):
561
  self.sequences = sequences
 
567
  return self.sequences[idx]
568
 
569
 
570
+ def build_collator(tokenizer) -> Callable[[list[str]], tuple[torch.Tensor, torch.Tensor]]:
571
+ def _collate_fn(sequences: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
572
+ """Collate function for batching sequences."""
573
+ return tokenizer(sequences, return_tensors="pt", padding='longest', pad_to_multiple_of=8)
574
+ return _collate_fn
575
 
576
+
577
+ class EmbeddingMixin:
578
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
579
+ raise NotImplementedError
 
580
 
581
  @property
582
  def device(self) -> torch.device:
583
  """Get the device of the model."""
584
  return next(self.parameters()).device
585
 
586
  def _read_sequences_from_db(self, db_path: str) -> set[str]:
587
  """Read sequences from SQLite database."""
588
  import sqlite3
 
599
 
600
  def embed_dataset(
601
  self,
602
+ sequences: List[str],
603
+ tokenizer: PreTrainedTokenizerBase,
604
  batch_size: int = 2,
605
  max_len: int = 512,
606
  full_embeddings: bool = False,
607
+ embed_dtype: torch.dtype = torch.float32,
608
+ pooling_types: List[str] = ['mean'],
609
  num_workers: int = 0,
610
  sql: bool = False,
611
+ save: bool = True,
612
  sql_db_path: str = 'embeddings.db',
613
+ save_path: str = 'embeddings.pth',
614
  ) -> Optional[dict[str, torch.Tensor]]:
615
  """Embed a dataset of protein sequences.
616
 
 
619
  batch_size: Batch size for processing
620
  max_len: Maximum sequence length
621
  full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
 
622
  pooling_type: Type of pooling ('mean' or 'cls')
623
  num_workers: Number of workers for data loading, 0 for the main process
624
  sql: Whether to store embeddings in SQLite database - will be stored in float32
 
626
 
627
  Returns:
628
  Dictionary mapping sequences to embeddings, or None if sql=True
629
+
630
+ Note:
631
+ - If sql=True, embeddings can only be stored in float32
632
+ - sql is ideal if you need to stream a very large dataset for training in real-time
633
+ - save=True is ideal if you can store the entire embedding dictionary in RAM
634
+ - If sql=True, the SQLite database is used regardless of the save setting
635
+ - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
636
+ - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
637
+
638
+ Example:
639
+ >>> embedder = EmbeddingMixin()
640
+ >>> embedding_dict = embedder.embed_dataset(
641
+ sequences=[
642
+ 'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
643
+ ],
644
+ batch_size=2, # adjust for your GPU memory
645
+ max_len=512, # adjust for your needs
646
+ full_embeddings=False, # if True, no pooling is performed
647
+ embed_dtype=torch.float32, # cast to what dtype you want
648
+ pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
649
+ num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
650
+ sql=False, # if True, embeddings will be stored in SQLite database
651
+ sql_db_path='embeddings.db',
652
+ save=True, # if True, embeddings will be saved as a .pth file
653
+ save_path='embeddings.pth',
654
+ )
655
+ >>> # embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
656
  """
657
+ sequences = list(set([seq[:max_len] for seq in sequences]))
658
+ sequences = sorted(sequences, key=len, reverse=True)
659
+ collate_fn = build_collator(tokenizer)
660
  device = self.device
661
+ pooler = Pooler(pooling_types) if not full_embeddings else None
662
 
663
  def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
664
+ if full_embeddings or residue_embeddings.ndim == 2: # if already pooled or want residue-wise embeddings
665
  return residue_embeddings
666
  else:
667
+ return pooler(residue_embeddings, attention_mask)
668
 
 
669
  if sql:
670
  import sqlite3
671
  conn = sqlite3.connect(sql_db_path)
 
676
  print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}")
677
  print(f"Embedding {len(to_embed)} new sequences")
678
  if len(to_embed) > 0:
 
679
  dataset = ProteinDataset(to_embed)
680
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
681
  with torch.no_grad():
682
  for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
683
+ seqs = to_embed[i * batch_size:(i + 1) * batch_size]
684
  input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
685
+ residue_embeddings = self._embed(input_ids, attention_mask).float() # sql requires float32
686
  embeddings = get_embeddings(residue_embeddings, attention_mask).cpu()
687
+ for seq, emb, mask in zip(seqs, embeddings, attention_mask):
688
+ if full_embeddings:
689
+ emb = emb[mask.bool()]
690
  c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
691
  (seq, emb.cpu().numpy().tobytes()))
692
 
693
  if (i + 1) % 100 == 0:
694
  conn.commit()
695
+
696
  conn.commit()
697
  conn.close()
698
  return None
699
+
700
  embeddings_dict = {}
701
+ if os.path.exists(save_path):
702
+ embeddings_dict = torch.load(save_path, map_location='cpu', weights_only=True)
703
+ to_embed = [seq for seq in sequences if seq not in embeddings_dict]
704
+ print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}")
705
+ print(f"Embedding {len(to_embed)} new sequences")
706
+ else:
707
+ to_embed = sequences
708
+ print(f"Embedding {len(to_embed)} new sequences")
709
+
710
+ if len(to_embed) > 0:
711
+ dataset = ProteinDataset(to_embed)
712
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, shuffle=False)
713
+ with torch.no_grad():
714
+ for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
715
+ seqs = to_embed[i * batch_size:(i + 1) * batch_size]
716
+ input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
717
+ residue_embeddings = self._embed(input_ids, attention_mask)
718
+ embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype).cpu()
719
+ for seq, emb in zip(seqs, embeddings):
720
+ embeddings_dict[seq] = emb
721
+
722
+ if save:
723
+ torch.save(embeddings_dict, save_path)
724
+
725
  return embeddings_dict
726
 
727
 
728
+ class FastEsmPreTrainedModel(PreTrainedModel):
729
+ """
730
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
731
+ models.
732
+ """
733
+ config_class = FastEsmConfig
734
+ base_model_prefix = "fastesm"
735
+ supports_gradient_checkpointing = True
736
+ tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
737
+ def _init_weights(self, module):
738
+ """Initialize the weights"""
739
+ if isinstance(module, nn.Linear):
740
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
741
+ if module.bias is not None:
742
+ module.bias.data.zero_()
743
+ elif isinstance(module, nn.Embedding):
744
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
745
+ if module.padding_idx is not None:
746
+ module.weight.data[module.padding_idx].zero_()
747
+ elif isinstance(module, nn.LayerNorm):
748
+ module.bias.data.zero_()
749
+ module.weight.data.fill_(1.0)
750
+
751
+ def get_input_embeddings(self) -> nn.Module:
752
+ try:
753
+ return self.embeddings.word_embeddings
754
+ except AttributeError:
755
+ return self.esm.embeddings.word_embeddings
756
+
757
+
758
+ class FAST_ESM_ENCODER(FastEsmPreTrainedModel, EmbeddingMixin):
759
+ def __init__(self, config, add_pooling_layer: Optional[bool] = True):
760
+ super(FastEsmPreTrainedModel, self).__init__(config)
761
  self.config = config
762
  self.embeddings = EsmEmbeddings(config)
763
  self.encoder = EsmEncoder(config)
764
+ self.contact_head = EsmContactPredictionHead(
765
+ in_features=config.num_hidden_layers * config.num_attention_heads, bias=True
766
+ )
767
  # Initialize weights and apply final processing
768
  self.post_init()
769
 
 
773
  def set_input_embeddings(self, value):
774
  self.embeddings.word_embeddings = value
775
 
776
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
777
+ token_embedding_output = self.embeddings(input_ids, attention_mask=attention_mask)
778
+ batch_size, seq_length = input_ids.shape
779
+ if attention_mask is not None:
780
+ extended_attention_mask = attention_mask[:, None, None, :].expand(
781
+ batch_size, 1, seq_length, seq_length
782
+ ).bool()
783
+ else:
784
+ extended_attention_mask = None
785
+ encoder_outputs = self.encoder(
786
+ token_embedding_output,
787
+ attention_mask=extended_attention_mask,
788
+ output_hidden_states=False,
789
+ output_attentions=False,
790
+ )
791
+ return encoder_outputs.last_hidden_state
792
+
793
+ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
794
+ attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions
795
+ attns = torch.stack(attns, dim=1)
796
+ attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3)
797
+ attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
798
+ return self.contact_head(input_ids, attns)
799
+
800
  def forward(
801
  self,
802
+ input_ids: Optional[torch.Tensor] = None,
803
  attention_mask: Optional[torch.Tensor] = None,
804
+ position_ids: Optional[torch.Tensor] = None,
805
+ inputs_embeds: Optional[torch.Tensor] = None,
806
  output_attentions: Optional[bool] = None,
807
  output_hidden_states: Optional[bool] = None,
808
  return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
 
834
  raise ValueError("You have to specify either input_ids or inputs_embeds")
835
 
836
  batch_size, seq_length = input_shape
837
+ token_embedding_output = self.embeddings(
838
  input_ids=input_ids,
839
  position_ids=position_ids,
840
  attention_mask=attention_mask,
 
849
  extended_attention_mask = None
850
 
851
  encoder_outputs = self.encoder(
852
+ token_embedding_output,
853
  attention_mask=extended_attention_mask,
854
  output_hidden_states=output_hidden_states,
855
  output_attentions=output_attentions,
 
863
  )
864
 
865
 
866
+ class FastEsmModel(FastEsmPreTrainedModel, EmbeddingMixin):
867
+ def __init__(self, config, add_pooling_layer: Optional[bool] = True):
868
+ super(FastEsmPreTrainedModel, self).__init__(config)
869
  self.config = config
870
  self.esm = FAST_ESM_ENCODER(config)
871
  self.pooler = EsmPooler(config) if add_pooling_layer else None
 
878
  def set_input_embeddings(self, value):
879
  self.embeddings.word_embeddings = value
880
 
881
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
882
+ return self.esm._embed(input_ids, attention_mask)
883
+
884
+ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
885
+ return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)
886
+
887
  def forward(
888
  self,
889
+ input_ids: Optional[torch.Tensor] = None,
890
  attention_mask: Optional[torch.Tensor] = None,
891
+ position_ids: Optional[torch.Tensor] = None,
892
+ inputs_embeds: Optional[torch.Tensor] = None,
893
  output_attentions: Optional[bool] = None,
894
  output_hidden_states: Optional[bool] = None,
895
  return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
 
939
  )
940
 
941
 
942
+ class FastEsmForMaskedLM(FastEsmPreTrainedModel, EmbeddingMixin):
943
  _tied_weights_keys = ["lm_head.decoder.weight"]
944
 
945
  def __init__(self, config):
946
+ super(FastEsmPreTrainedModel, self).__init__(config)
947
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
948
  self.lm_head = EsmLMHead(config)
949
  self.loss_fct = nn.CrossEntropyLoss()
 
955
  def set_output_embeddings(self, new_embeddings):
956
  self.lm_head.decoder = new_embeddings
957
 
958
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
959
+ return self.esm._embed(input_ids, attention_mask)
960
+
961
+ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
962
+ return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)
963
+
964
  def forward(
965
  self,
966
+ input_ids: Optional[torch.Tensor] = None,
967
  attention_mask: Optional[torch.Tensor] = None,
968
+ position_ids: Optional[torch.Tensor] = None,
969
+ inputs_embeds: Optional[torch.Tensor] = None,
970
+ labels: Optional[torch.Tensor] = None,
971
  output_attentions: Optional[bool] = None,
972
  output_hidden_states: Optional[bool] = None,
973
  return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
 
996
  attentions=outputs.attentions,
997
  )
998
 
999
 
1000
+ class FastEsmForSequenceClassification(FastEsmPreTrainedModel, EmbeddingMixin):
1001
  def __init__(self, config):
1002
+ super(FastEsmPreTrainedModel, self).__init__(config)
1003
  self.num_labels = config.num_labels
1004
  self.config = config
1005
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
 
1009
  self.bce = nn.BCEWithLogitsLoss()
1010
  self.init_weights()
1011
 
1012
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1013
+ return self.esm._embed(input_ids, attention_mask)
1014
+
1015
+ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
1016
+ return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)
1017
+
1018
  def forward(
1019
  self,
1020
+ input_ids: Optional[torch.Tensor] = None,
1021
  attention_mask: Optional[torch.Tensor] = None,
1022
+ position_ids: Optional[torch.Tensor] = None,
1023
+ inputs_embeds: Optional[torch.Tensor] = None,
1024
+ labels: Optional[torch.Tensor] = None,
1025
  output_attentions: Optional[bool] = None,
1026
  output_hidden_states: Optional[bool] = None,
1027
  return_dict: Optional[bool] = None
 
1066
  )
1067
 
1068
 
1069
+ class FastEsmForTokenClassification(FastEsmPreTrainedModel, EmbeddingMixin):
1070
  def __init__(self, config):
1071
+ super(FastEsmPreTrainedModel, self).__init__(config)
1072
  self.num_labels = config.num_labels
1073
  self.esm = FAST_ESM_ENCODER(config, add_pooling_layer=False)
1074
  self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
1076
  self.loss_fct = nn.CrossEntropyLoss()
1077
  self.init_weights()
1078
 
1079
+ def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1080
+ return self.esm._embed(input_ids, attention_mask)
1081
+
1082
+ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
1083
+ return self.esm.predict_contacts(input_ids, attention_mask=attention_mask)
1084
+
1085
  def forward(
1086
  self,
1087
+ input_ids: Optional[torch.Tensor] = None,
1088
  attention_mask: Optional[torch.Tensor] = None,
1089
+ position_ids: Optional[torch.Tensor] = None,
1090
+ inputs_embeds: Optional[torch.Tensor] = None,
1091
+ labels: Optional[torch.Tensor] = None,
1092
  output_attentions: Optional[bool] = None,
1093
  output_hidden_states: Optional[bool] = None,
1094
  return_dict: Optional[bool] = None
 
1148
  tokenizer = EsmTokenizer.from_pretrained(model_path)
1149
  config = FastEsmConfig.from_pretrained(model_path)
1150
  fast_model = FastEsmForMaskedLM(config).from_pretrained(model_path).to(device)
1151
+ print('fast model')
1152
+ print(fast_model)
1153
  model = TransformersEsmModel.from_pretrained(model_path, token_dropout=False).to(device)
1154
+ print('transformers model')
1155
+ print(model)
1156
 
1157
  counts = [0] * len(tolerances)
1158
  for _ in range(seq_count):
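
Below is a minimal usage sketch of the embedding API introduced in this commit (EmbeddingMixin.embed_dataset with the new Pooler and explicit tokenizer argument). The checkpoint path and example sequences are placeholders, and the import assumes modeling_fastesm.py is available locally; this is an illustration, not part of the commit itself.

import torch
from transformers import EsmTokenizer
from modeling_fastesm import FastEsmModel  # assumes this file is importable locally

MODEL_PATH = "path/to/fastesm-checkpoint"  # hypothetical checkpoint location
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
model = FastEsmModel.from_pretrained(MODEL_PATH).eval()

sequences = [
    "MALWMRLLPLLALLALWGPDPAAA",
    "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
]

# Pooled embeddings: 'mean' and 'cls' are concatenated by Pooler, so each
# sequence maps to a tensor of shape (2 * hidden_size,). With save=True the
# dictionary is also written to save_path and reloaded on the next call.
embedding_dict = model.embed_dataset(
    sequences=sequences,
    tokenizer=tokenizer,
    batch_size=2,
    max_len=512,
    full_embeddings=False,
    embed_dtype=torch.float16,
    pooling_types=["mean", "cls"],
    num_workers=0,
    sql=False,
    save=True,
    save_path="embeddings.pth",
)
for seq, emb in embedding_dict.items():
    print(seq[:12], tuple(emb.shape))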