lhallee committed · Commit b5d932c · verified · 1 Parent(s): 4f950cb

Update README.md

Files changed (1): README.md (+42, -4)
README.md CHANGED

---
# FastESM

## A faster half-precision version of ESM2-650 that leverages FlashAttention2

FastESM is a fully Hugging Face-compatible version of ESM2-650, rewritten with PyTorch's newer attention implementation (SDPA), which runs FlashAttention2 when possible.

To produce the FastESM weights, we trained ESM2-650 for 50,000 additional steps in fp16 mixed precision on [OMG50](https://huggingface.co/datasets/tattabio/OMG_prot50) at sequence lengths up to **2048**.

Outputting attentions and predicting contacts are not possible with SDPA. Various other optimizations also make the base implementation slightly different from the HF one.

## Use with 🤗 transformers
```python
import torch
from transformers import AutoModel, AutoTokenizer

# The repo id below is illustrative; load from this model's repo.
# trust_remote_code is needed for the custom SDPA implementation.
model_path = 'Synthyra/FastESM2_650'
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path)

sequence = 'MSEQWENCE'  # toy 9-residue sequence -> 11 tokens with CLS/EOS
tokenized = tokenizer(sequence, return_tensors='pt')
with torch.no_grad():
    embeddings = model(**tokenized).last_hidden_state

print(embeddings.shape) # (1, 11, 1280)
```
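
If you need one vector per protein rather than residue-wise embeddings, a common approach is masked mean pooling over the last hidden state. A minimal sketch continuing from the snippet above (not part of the original README):

```python
# Mean-pool residue embeddings into one vector per sequence,
# ignoring padding positions via the attention mask.
mask = tokenized['attention_mask'].unsqueeze(-1)   # (batch, seq_len, 1)
summed = (embeddings * mask).sum(dim=1)            # (batch, hidden)
pooled = summed / mask.sum(dim=1).clamp(min=1)     # (batch, hidden)
print(pooled.shape)  # (1, 1280)
```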

## Embed entire datasets with no new code
To embed a list of protein sequences **fast**, just call `embed_dataset`. Sequences are sorted by length to reduce padding tokens, so the progress bar's time estimate is usually much longer than the actual runtime.
```python
embeddings = model.embed_dataset(
    sequences=sequences,            # list of protein strings
    batch_size=16,                  # embedding batch size
    max_len=2048,                   # truncate to max_len
    full_embeddings=True,           # return residue-wise embeddings
    full_precision=False,           # if True, store as float32; otherwise half precision
    pooling_type='mean',            # use mean pooling for protein-wise embeddings
    num_workers=0,                  # data loading num workers
    sql=False,                      # return a dictionary of sequences and embeddings
)

_ = model.embed_dataset(
    sequences=sequences,            # list of protein strings
    batch_size=16,                  # embedding batch size
    max_len=2048,                   # truncate to max_len
    full_embeddings=True,           # return residue-wise embeddings
    full_precision=False,           # if True, store as float32; otherwise half precision
    pooling_type='mean',            # use mean pooling for protein-wise embeddings
    num_workers=0,                  # data loading num workers
    sql=True,                       # store sequences and embeddings in a local SQL database
    sql_db_path='embeddings.db',    # path to .db file of choice
)
```
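
The README does not document the on-disk schema for the SQL path. Assuming a simple table that maps each sequence to its embedding bytes (table and column names below are hypothetical), reading results back might look like:

```python
import sqlite3
import numpy as np

# Hypothetical schema: embeddings(sequence TEXT PRIMARY KEY, embedding BLOB).
# Inspect the .db file and adjust names/dtype to what embed_dataset actually writes.
con = sqlite3.connect('embeddings.db')
row = con.execute(
    'SELECT embedding FROM embeddings WHERE sequence = ?', (sequences[0],)
).fetchone()
vector = np.frombuffer(row[0], dtype=np.float16)  # half precision per full_precision=False
con.close()
```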

## Comparison of half precisions
Presumably because we trained in mixed-precision fp16, fp16 outputs are closer to those of the fp32 weights than bf16 outputs are. Therefore, we recommend loading in fp16.

When averaging the MSE of 1000 sequences vs. the fp32 weights:

Average MSE for FP16: 0.00000140

Average MSE for BF16: 0.00004125
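
A sketch of how such a comparison can be computed (an illustrative harness, not the authors' exact script; the repo id and GPU usage are assumptions):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_path = 'Synthyra/FastESM2_650'  # illustrative repo id
tokenizer = AutoTokenizer.from_pretrained(model_path)
fp32 = AutoModel.from_pretrained(model_path, torch_dtype=torch.float32, trust_remote_code=True).cuda().eval()
fp16 = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).cuda().eval()

total_mse = 0.0
with torch.no_grad():
    for seq in sequences:  # e.g. 1000 protein strings
        batch = tokenizer(seq, return_tensors='pt').to('cuda')
        ref = fp32(**batch).last_hidden_state
        out = fp16(**batch).last_hidden_state.float()
        total_mse += torch.nn.functional.mse_loss(out, ref).item()

print(f'Average MSE: {total_mse / len(sequences):.8f}')
```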

### FlashAttention2
Requires PyTorch 2.5+ for the most savings; see [SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html).
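
To check whether FlashAttention2 can actually run on your hardware and dtype, one option (assuming PyTorch 2.3+, which exposes torch.nn.attention.sdpa_kernel) is to force the FlashAttention backend:

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

# SDPA raises an error inside this context if FlashAttention2
# cannot run for this device/dtype/shape combination.
q = k = v = torch.randn(1, 8, 128, 64, device='cuda', dtype=torch.float16)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 128, 64])
```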

### Citation