Upload README.md with huggingface_hub
README.md CHANGED
@@ -11,19 +11,30 @@ Load any ESM2 models into a FastEsm model to dramatically speed up training and
 Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
 Various other optimizations also make the base implementation slightly different than the one in transformers.
 
-# FastESM2-650
-
-## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
-To enhance the weights with longer context and better fp16 support, we trained ESM2-650 50000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to sequence length of **2048**.
-
 ## Use with 🤗 transformers
 
+### Supported models
+```python
+model_dict = {
+    # Synthyra/ESM2-8M
+    'ESM2-8M': 'facebook/esm2_t6_8M_UR50D',
+    # Synthyra/ESM2-35M
+    'ESM2-35M': 'facebook/esm2_t12_35M_UR50D',
+    # Synthyra/ESM2-150M
+    'ESM2-150M': 'facebook/esm2_t30_150M_UR50D',
+    # Synthyra/ESM2-650M
+    'ESM2-650M': 'facebook/esm2_t33_650M_UR50D',
+    # Synthyra/ESM2-3B
+    'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
+}
+```
+
 ### For working with embeddings
 ```python
 import torch
 from transformers import AutoModel, AutoTokenizer
 
-model_path = 'Synthyra/
+model_path = 'Synthyra/ESM2-8M'
 model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
 tokenizer = model.tokenizer
 
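
Read together, the pieces added in this hunk give a complete quick start. The sketch below is illustrative rather than part of the committed README: it uses the `Synthyra/ESM2-8M` checkpoint named in the hunk and assumes the model outputs expose the standard `last_hidden_state` and `attentions` attributes from transformers, which the attention snippet elsewhere in the README relies on.

```python
import torch
from transformers import AutoModel

# Illustrative sketch (not part of the diff): load one of the supported checkpoints
# listed above and pull per-residue embeddings plus manually computed attentions.
model_path = 'Synthyra/ESM2-8M'  # checkpoint shown in the updated README
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer  # the FastEsm model carries its own tokenizer, per the README

sequences = ['MPRTEINSEQ', 'MALWMRLLPL']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')

with torch.no_grad():
    # output_attentions triggers the manual (non-SDPA) attention path described above
    output = model(**tokenized, output_attentions=True)

print(output.last_hidden_state.shape)  # (batch, tokens, hidden) -- assumed standard transformers output
print(output.attentions[-1].shape)     # (batch, heads, tokens, tokens)
```
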
@@ -59,52 +70,61 @@ with torch.no_grad():
 print(attentions[-1].shape) # (2, 20, 11, 11)
 ```
 
+### Contact prediction
+Because we can output attentions using the naive attention implementation, the contact prediction is also supported
+```python
+with torch.no_grad():
+    contact_map = model.predict_contacts(**tokenized).squeeze().cpu().numpy() # (seq_len, seq_len)
+```
+
 
 ## Embed entire datasets with no new code
-To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time.
-```python
-embeddings = model.embed_dataset(
-    sequences=sequences, # list of protein strings
-    batch_size=16, # embedding batch size
-    max_len=2048, # truncate to max_len
-    full_embeddings=True, # return residue-wise embeddings
-    full_precision=False, # store as float32
-    pooling_type='mean', # use mean pooling if protein-wise embeddings
-    num_workers=0, # data loading num workers
-    sql=False, # return dictionary of sequences and embeddings
-)
+To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time it will take.
 
-
-
-
-
-
-
-
-
-
-
+Example:
+```python
+embedding_dict = model.embed_dataset(
+    sequences=[
+        'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
+    ],
+    batch_size=2, # adjust for your GPU memory
+    max_len=512, # adjust for your needs
+    full_embeddings=False, # if True, no pooling is performed
+    embed_dtype=torch.float32, # cast to what dtype you want
+    pooling_type=['mean', 'cls'], # more than one pooling type will be concatenated together
+    num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
+    sql=False, # if True, embeddings will be stored in SQLite database
+    sql_db_path='embeddings.db',
+    save=True, # if True, embeddings will be saved as a .pth file
+    save_path='embeddings.pth',
 )
+# embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
 ```
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+```
+model.embed_dataset()
+Args:
+    sequences: List of protein sequences
+    batch_size: Batch size for processing
+    max_len: Maximum sequence length
+    full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
+    pooling_type: Type of pooling ('mean' or 'cls')
+    num_workers: Number of workers for data loading, 0 for the main process
+    sql: Whether to store embeddings in SQLite database - will be stored in float32
+    sql_db_path: Path to SQLite database
+
+Returns:
+    Dictionary mapping sequences to embeddings, or None if sql=True
+
+Note:
+    - If sql=True, embeddings can only be stored in float32
+    - sql is ideal if you need to stream a very large dataset for training in real-time
+    - save=True is ideal if you can store the entire embedding dictionary in RAM
+    - sql will be used if it is True and save is True or False
+    - If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
+    - Sequences will be truncated to max_len and sorted by length in descending order for faster processing
+```
 
-### Inference speed
-We look at various ESM models and their throughput on an H100. FastESM is over twice as fast as ESM2-650 with longer sequences. Requires PyTorch 2.5+ for the most savings, see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
-
 
 ### Citation
 If you use any of this implementation or work please cite it (as well as the [ESM2](https://www.science.org/doi/10.1126/science.ade2574) paper).
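
The `embed_dataset` example added in this hunk writes `embeddings.pth` when `save=True`. As a companion, here is a small sketch of reading those embeddings back; it is not part of the model card and assumes the saved file holds the same sequence-to-tensor dictionary that `embed_dataset` returns, as the comment in the hunk describes.

```python
import torch

# Companion sketch (not in the diff): reload embeddings written by
# embed_dataset(..., save=True, save_path='embeddings.pth').
# Assumption: the .pth file stores the same {sequence: tensor} dictionary
# that embed_dataset returns in memory.
embedding_dict = torch.load('embeddings.pth')

query = 'MALWMRLLPLLALLALWGPDPAAA'
embedding = embedding_dict[query]
# Pooled embeddings are 1D; with full_embeddings=True they are (seq_len, hidden_dim).
# With pooling_type=['mean', 'cls'], the pooled vector is the two poolings concatenated.
print(embedding.shape)
```
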
@@ -117,4 +137,4 @@ If you use any of this implementation or work please cite it (as well as the [ES
 doi = { 10.57967/hf/3729 },
 publisher = { Hugging Face }
 }
-```
+```
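
The contact-prediction snippet added in the second hunk returns a square (seq_len, seq_len) map. A quick way to inspect it is to render it as a heatmap; the sketch below is illustrative, reuses the checkpoint and tokenizer pattern from the README, and assumes matplotlib is available.

```python
import torch
import matplotlib.pyplot as plt
from transformers import AutoModel

# Illustrative sketch (not in the diff): visualize the predicted contact map.
model = AutoModel.from_pretrained('Synthyra/ESM2-8M', torch_dtype=torch.float16, trust_remote_code=True).eval()
tokenizer = model.tokenizer

sequence = 'MALWMRLLPLLALLALWGPDPAAA'
tokenized = tokenizer(sequence, return_tensors='pt')

with torch.no_grad():
    # predict_contacts is the method shown in the diff; cast up from fp16 before plotting
    contact_map = model.predict_contacts(**tokenized).squeeze().float().cpu().numpy()  # (seq_len, seq_len)

plt.imshow(contact_map, cmap='viridis')
plt.colorbar(label='contact probability')
plt.title('Predicted contact map')
plt.savefig('contact_map.png', dpi=150)
```
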