lhallee committed · Commit 22db99a · verified · Parent(s): 9966ab1

Upload README.md with huggingface_hub

Files changed (1): README.md +66 -46
README.md CHANGED
@@ -11,19 +11,30 @@ Load any ESM2 models into a FastEsm model to dramatically speed up training and
  Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
  Various other optimizations also make the base implementation slightly different than the one in transformers.
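For example, a minimal sketch of requesting attention maps (the sequences here are placeholders, and `model`/`tokenizer` are assumed to be loaded as in the examples further down):
```python
import torch

# Hypothetical input sequences; any list of protein strings works.
seqs = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(seqs, padding=True, return_tensors='pt')
with torch.no_grad():
    # output_attentions=True falls back to manual (non-SDPA) attention so the maps can be returned.
    outputs = model(**tokenized, output_attentions=True)
attentions = outputs.attentions  # tuple with one (batch, heads, seq_len, seq_len) tensor per layer
```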
 
- # FastESM2-650
-
- ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
- To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to a sequence length of **2048**.
-
  ## Use with 🤗 transformers
 
  ### For working with embeddings
  ```python
  import torch
  from transformers import AutoModel, AutoTokenizer
 
- model_path = 'Synthyra/FastESM2_650'
  model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
  tokenizer = model.tokenizer
 
@@ -59,52 +70,61 @@ with torch.no_grad():
  print(attentions[-1].shape) # (2, 20, 11, 11)
  ```
 
 
  ## Embed entire datasets with no new code
- To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time.
- ```python
- embeddings = model.embed_dataset(
-     sequences=sequences, # list of protein strings
-     batch_size=16, # embedding batch size
-     max_len=2048, # truncate to max_len
-     full_embeddings=True, # return residue-wise embeddings
-     full_precision=False, # store as float32
-     pooling_type='mean', # use mean pooling if protein-wise embeddings
-     num_workers=0, # data loading num workers
-     sql=False, # return dictionary of sequences and embeddings
- )
 
- _ = model.embed_dataset(
-     sequences=sequences, # list of protein strings
-     batch_size=16, # embedding batch size
-     max_len=2048, # truncate to max_len
-     full_embeddings=True, # return residue-wise embeddings
-     full_precision=False, # store as float32
-     pooling_type='mean', # use mean pooling if protein-wise embeddings
-     num_workers=0, # data loading num workers
-     sql=True, # store sequences in local SQL database
-     sql_db_path='embeddings.db', # path to .db file of choice
  )
  ```
 
- ## Model probes
- We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
-
- The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel), and regression tasks are averaged between Spearman rho and R2.
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png)
-
- ## Comparison of half precisions
- Presumably because we trained in mixed-precision fp16, fp16 outputs are closer to those of the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16.
-
- When summing the MSE of 1000 sequences vs. the fp32 weights:
-
- Average MSE for FP16: 0.00000140
-
- Average MSE for BF16: 0.00004125
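The exact procedure is not spelled out; a rough sketch of one way such a comparison could be run (the sequence list, the last-hidden-state comparison, and the mean-of-MSE metric are assumptions):
```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/FastESM2_650'
ref = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval()
tokenizer = ref.tokenizer

def average_mse(dtype, sequences):
    # Compare last hidden states of a reduced-precision copy against the fp32 reference.
    test = AutoModel.from_pretrained(model_path, torch_dtype=dtype, trust_remote_code=True).eval()
    total = 0.0
    for seq in sequences:
        tok = tokenizer(seq, return_tensors='pt')
        with torch.no_grad():
            a = ref(**tok).last_hidden_state.float()
            b = test(**tok).last_hidden_state.float()
        total += torch.mean((a - b) ** 2).item()
    return total / len(sequences)

# e.g. compare average_mse(torch.float16, sequences) vs. average_mse(torch.bfloat16, sequences)
```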
 
- ### Inference speed
- We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 with longer sequences. Requires PyTorch 2.5+ for the most savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png)
 
  ### Citation
  If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
@@ -117,4 +137,4 @@ If you use any of this implementation or work please cite it (as well as the [ES
  doi = { 10.57967/hf/3729 },
  publisher = { Hugging Face }
  }
- ```
 
  Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
  Various other optimizations also make the base implementation slightly different than the one in transformers.
 
  ## Use with 🤗 transformers
 
+ ### Supported models
+ ```python
+ model_dict = {
+     # Synthyra/ESM2-8M
+     'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
+     # Synthyra/ESM2-35M
+     'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
+     # Synthyra/ESM2-150M
+     'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
+     # Synthyra/ESM2-650M
+     'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
+     # Synthyra/ESM2-3B
+     'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
+ }
+ ```
+
  ### For working with embeddings
  ```python
  import torch
  from transformers import AutoModel, AutoTokenizer
 
+ model_path = 'Synthyra/ESM2-8M'
  model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
  tokenizer = model.tokenizer
 
  print(attentions[-1].shape) # (2, 20, 11, 11)
  ```
 
+ ### Contact prediction
+ Because we can output attentions using the naive attention implementation, contact prediction is also supported.
+ ```python
+ with torch.no_grad():
+     contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
+ ```
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)
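The image above is presumably a rendering of such a map; a small sketch of one way to plot it (matplotlib is an assumption, not a stated dependency, and `contact_map` comes from the call above for a single sequence):
```python
import matplotlib.pyplot as plt

# contact_map has shape (seq_len, seq_len); the ESM2 contact head outputs contact probabilities.
plt.imshow(contact_map, cmap='Greys')
plt.xlabel('Residue index')
plt.ylabel('Residue index')
plt.colorbar(label='Predicted contact probability')
plt.savefig('contact_map.png', dpi=200)
```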
 
  ## Embed entire datasets with no new code
+ To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress-bar estimate is usually much longer than the actual time it will take.
 
+ Example:
+ ```python
+ embedding_dict = model.embed_dataset(
+     sequences=[
+         'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
+     ],
+     batch_size=2, # adjust for your GPU memory
+     max_len=512, # adjust for your needs
+     full_embeddings=False, # if True, no pooling is performed
+     embed_dtype=torch.float32, # cast to the dtype you want
+     pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
+     num_workers=0, # if you have many CPU cores, we find num_workers=4 is fast for large datasets
+     sql=False, # if True, embeddings will be stored in a SQLite database
+     sql_db_path='embeddings.db',
+     save=True, # if True, embeddings will be saved as a .pth file
+     save_path='embeddings.pth',
  )
+ # embedding_dict is a dictionary mapping sequences to their embeddings, as tensors (for .pth) or numpy arrays (for sql)
  ```
 
+ ```
+ model.embed_dataset()
+ Args:
+     sequences: List of protein sequences
+     batch_size: Batch size for processing
+     max_len: Maximum sequence length
+     full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
+     pooling_type: Type of pooling ('mean' or 'cls')
+     num_workers: Number of workers for data loading, 0 for the main process
+     sql: Whether to store embeddings in a SQLite database (stored in float32)
+     sql_db_path: Path to the SQLite database
+
+ Returns:
+     Dictionary mapping sequences to embeddings, or None if sql=True
+
+ Note:
+     - If sql=True, embeddings can only be stored in float32
+     - sql is ideal if you need to stream a very large dataset for training in real time
+     - save=True is ideal if you can store the entire embedding dictionary in RAM
+     - If sql=True, the SQLite database is used regardless of the save setting
+     - If your SQL database or .pth file is already present, it will be scanned first for already-embedded sequences
+     - Sequences are truncated to max_len and sorted by length in descending order for faster processing
+ ```
 
 
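Since save=True writes the embedding dictionary to a .pth file, it can presumably be reloaded with torch.load for downstream use; a minimal sketch (the path and example key match the call above):
```python
import torch

# Reload the {sequence: embedding} dictionary produced by embed_dataset(save=True, ...).
embedding_dict = torch.load('embeddings.pth')
emb = embedding_dict['MALWMRLLPLLALLALWGPDPAAA']  # key is one of the embedded sequences
print(emb.shape)  # pooled vector here; (seq_len, hidden_size) if full_embeddings=True
```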
  ### Citation
  If you use any of this implementation or work, please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
 
  doi = { 10.57967/hf/3729 },
  publisher = { Hugging Face }
  }
+ ```