lhallee committed · Commit 2b3b1f4 · verified · 1 Parent(s): 57660d5

Upload README.md with huggingface_hub

Files changed (1)
  1. README.md +25 -29
README.md CHANGED
@@ -3,9 +3,6 @@ library_name: transformers
  tags: []
  ---
 
- # NOTE
- There was previously a bug with Huggingface weight tying that caused the logits of FastESM to differ from ESM2. That bug is now resolved.
-
  # FastESM
  FastESM is a Huggingface-compatible plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.

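For intuition, here is a minimal sketch of what the newer attention implementation amounts to; the shapes are illustrative and this is not FastESM's actual module code. The eager softmax(QK^T / sqrt(d))V computation is replaced by PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` kernel.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 20, 11, 64) for _ in range(3))

# Eager attention, conceptually what the stock ESM2 modeling code computes
eager = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1) @ v

# Fused SDPA kernel; dispatches to flash / memory-efficient backends when available
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(eager, fused, atol=1e-5))  # matches up to numerical tolerance
```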
 
@@ -14,30 +11,19 @@ Load any ESM2 models into a FastEsm model to dramatically speed up training and
  Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
  Various other optimizations also make the base implementation slightly different from the one in transformers.
 
- ## Use with 🤗 transformers
- 
- ### Supported models
- ```python
- model_dict = {
-     # Synthyra/ESM2-8M
-     'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
-     # Synthyra/ESM2-35M
-     'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
-     # Synthyra/ESM2-150M
-     'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
-     # Synthyra/ESM2-650M
-     'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
-     # Synthyra/ESM2-3B
-     'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
- }
- ```
+ # FastESM2-650
+ 
+ ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
+ To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to a sequence length of **2048**.
+ 
+ ## Use with 🤗 transformers
 
  ### For working with embeddings
  ```python
  import torch
  from transformers import AutoModel, AutoTokenizer
 
- model_path = 'Synthyra/ESM2-8M'
+ model_path = 'Synthyra/FastESM2_650'
  model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
  tokenizer = model.tokenizer

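A minimal, self-contained sketch of the `output_attentions` path described above; the two toy sequences are placeholders and any protein strings work.

```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequences = ['MPRTEIN', 'MSEQWENCE']  # placeholder sequences
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
with torch.no_grad():
    output = model(**tokenized, output_attentions=True)  # attention computed manually, so slower than SDPA

attentions = output.attentions  # tuple with one (batch, heads, seq_len, seq_len) tensor per layer
print(attentions[-1].shape)
```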
 
@@ -73,14 +59,6 @@ with torch.no_grad():
  print(attentions[-1].shape) # (2, 20, 11, 11)
  ```
 
- ### Contact prediction
- Because we can output attentions using the naive attention implementation, contact prediction is also supported:
- ```python
- with torch.no_grad():
-     contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
- ```
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/9707OSXZ3Wdgn0Ni-55T-.png)
-
  ## Embed entire datasets with no new code
  To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimate is usually much longer than the time the run actually takes.

@@ -129,6 +107,24 @@ Note:
  - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
  ```

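A rough usage sketch for `embed_dataset`, called on the model loaded in the example above. The argument names `sequences`, `batch_size`, and `max_len` are assumptions based on the notes; check the complete README for the exact signature and return type.

```python
# Hypothetical call shape; consult the full embed_dataset docstring for the real arguments.
sequences = [
    'MALWMRLLPLLALLALWGPDPAAA',           # placeholder protein sequences
    'MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',
]
embeddings = model.embed_dataset(
    sequences=sequences,  # list of amino-acid strings
    batch_size=2,         # sequences per forward pass
    max_len=2048,         # longer sequences are truncated to this length
)
```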
+ ## Model probes
+ We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
+ 
+ The plot below showcases performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel), and regression tasks are averaged between Spearman rho and R2.
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/d1Xi6k1Q4-9By_MtzTvdV.png)
+ 
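As a rough illustration of the probing setup (not the exact pipeline from the paper), a linear probe can be fit on mean-pooled embeddings and scored with the MCC/F1 average used in the plot. The random arrays below stand in for real embeddings and labels.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, matthews_corrcoef
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(512, 1280))   # placeholder for (n_sequences, hidden_size) pooled embeddings
y = rng.integers(0, 2, size=512)   # placeholder binary property labels
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

probe = LogisticRegression(max_iter=1000).fit(X_train, y_train)
preds = probe.predict(X_test)

# Classification probe score = average of MCC and F1, as in the plot above
score = 0.5 * (matthews_corrcoef(y_test, preds) + f1_score(y_test, preds))
print(f'Averaged probe score: {score:.3f}')
```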
+ ## Comparison of half precisions
+ Presumably because we trained in mixed-precision fp16, fp16 outputs are closer to the fp32 weights than bf16 outputs. Therefore, we recommend loading in fp16.
+ 
+ When summing the MSE over 1000 sequences vs. the fp32 weights:
+ 
+ Average MSE for FP16: 0.00000140
+ 
+ Average MSE for BF16: 0.00004125
+ 
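A minimal sketch of that comparison; the two placeholder sequences stand in for the 1000-sequence evaluation set, which is not part of this README.

```python
import torch
from transformers import AutoModel

model_path = 'Synthyra/FastESM2_650'
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # half precision is intended for GPU inference

model_fp32 = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).eval().to(device)
model_fp16 = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval().to(device)
tokenizer = model_fp32.tokenizer

tokenized = tokenizer(['MPRTEIN', 'MSEQWENCE'], padding=True, return_tensors='pt').to(device)
with torch.no_grad():
    ref = model_fp32(**tokenized).last_hidden_state
    half = model_fp16(**tokenized).last_hidden_state.float()

print(f'MSE vs. fp32: {torch.mean((ref - half) ** 2).item():.8f}')
```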
+ ### Inference speed
+ We compare various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 on longer sequences. PyTorch 2.5+ is required for the most savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/62f2bd3bdb7cbd214b658c48/PvaBGfuJXEW2v_WLkt63y.png)

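A rough way to measure throughput on your own hardware; this is not the benchmark script behind the plot, and absolute numbers vary with GPU, dtype, and PyTorch version.

```python
import time
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('Synthyra/FastESM2_650', torch_dtype=torch.float16, trust_remote_code=True).eval().cuda()
tokenizer = model.tokenizer

batch = ['M' + 'ACDEFGHIKLMNPQRSTVWY' * 50] * 8  # 8 placeholder sequences of ~1000 residues
tokenized = tokenizer(batch, padding=True, return_tensors='pt').to('cuda')

with torch.no_grad():
    model(**tokenized)  # warmup
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        model(**tokenized)
    torch.cuda.synchronize()

elapsed = time.time() - start
print(f'{10 * len(batch) / elapsed:.1f} sequences / second')
```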
 
  ### Citation
  If you use any of this implementation or work please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
@@ -141,4 +137,4 @@ If you use any of this implementation or work please cite it (as well as the [ES
  doi = { 10.57967/hf/3729 },
  publisher = { Hugging Face }
  }
- ```
+ ```