lhallee committed
Commit 33bcbb8 · verified · 1 Parent(s): 45ab4a0

Update README.md

Files changed (1)
  1. README.md +46 -10
README.md CHANGED
@@ -2,37 +2,72 @@
  library_name: transformers
  tags: []
  ---

  # FastESM

- ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context

- FastESM is a Huggingface compatible plug in version of ESM2-650M rewritten with a newer PyTorch attention implementation.

- To enhance the weights with longer context and better fp16 support, we trained ESM2-650 50000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to sequence length of **2048**.

  Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
  Various other optimizations also make the base implementation slightly different than the one in transformers.

  ## Use with 🤗 transformers
  ```python
  import torch
- from transformers import AutoModelForMaskedLM, AutoTokenizer

  model_path = 'Synthyra/FastESM2_650'
- model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
- tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

  sequences = ['MPRTEIN', 'MSEQWENCE']
  tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
  with torch.no_grad():
-     embeddings = model(**tokenized, output_hidden_states=True).hidden_states[-1]

  print(embeddings.shape) # (2, 11, 1280)
  ```

- Please note that FastESM does not currently work with AutoModel.
- If you would like to train a model from scratch without a language modeling head you can still use the base code, but if you load the weights with AutoModel they will not map correctly.
- AutoModelForSequenceClassification and AutoModelForTokenClassification are working as intended.

  ## Embed entire datasets with no new code
  To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time.

@@ -60,6 +95,7 @@ _ = model.embed_dataset(
      sql_db_path='embeddings.db', # path to .db file of choice
  )
  ```

  ## Model probes
  We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.

  library_name: transformers
  tags: []
  ---
+
  # FastESM
+ FastESM is a Hugging Face-compatible, plug-in version of ESM2 rewritten with a newer PyTorch attention implementation.

+ Load any ESM2 model into a FastEsm model to dramatically speed up training and inference without **ANY** cost in performance.

+ ## Use with 🤗 transformers
+ ```python
+ from transformers import AutoModel, AutoModelForMaskedLM, AutoModelForSequenceClassification, AutoModelForTokenClassification  # any of these work
+
+ model_dict = {
+     'ESM2-8': 'facebook/esm2_t6_8M_UR50D',
+     'ESM2-35': 'facebook/esm2_t12_35M_UR50D',
+     'ESM2-150': 'facebook/esm2_t30_150M_UR50D',
+     'ESM2-650': 'facebook/esm2_t33_650M_UR50D',
+     'ESM2-3B': 'facebook/esm2_t36_3B_UR50D',
+     'ESM2-15B': 'facebook/esm2_t48_15B_UR50D',
+ }
+
+ model = AutoModelForMaskedLM.from_pretrained(model_dict['ESM2-8'], trust_remote_code=True)
+ tokenizer = model.tokenizer
+ ```

  Outputting attention maps (or the contact prediction head) is not natively possible with SDPA. You can still pass ```output_attentions``` to have attention calculated manually and returned.
  Various other optimizations also make the base implementation slightly different than the one in transformers.
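
For example, attention maps can be requested explicitly (a minimal sketch; the `attentions` field name assumes the standard transformers output convention):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('Synthyra/FastESM2_650', trust_remote_code=True).eval()
tokenizer = model.tokenizer

tokenized = tokenizer(['MPRTEIN'], return_tensors='pt')
with torch.no_grad():
    # Attention is computed manually when output_attentions=True (SDPA cannot return it)
    outputs = model(**tokenized, output_attentions=True)

# One tensor per layer, each shaped (batch, heads, seq_len, seq_len)
print(len(outputs.attentions), outputs.attentions[0].shape)
```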

+ # FastESM2-650
+
+ ## A faster half-precision version of ESM2-650 with FlashAttention2 and longer context
+ To enhance the weights with longer context and better fp16 support, we trained ESM2-650 for 50,000 additional steps with a traditional MLM objective (20% masking) in fp16 mixed precision on [OMGprot50](https://huggingface.co/datasets/tattabio/OMG_prot50) up to a sequence length of **2048**.
+
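
For illustration, a continued-pretraining setup along these lines can be sketched with the standard transformers Trainer. The 20% masking rate, fp16 mixed precision, 50,000 steps, and 2048 maximum length come from the description above; the dataset column name, batch size, and learning rate are assumptions:

```python
from datasets import load_dataset
from transformers import (AutoModelForMaskedLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained('facebook/esm2_t33_650M_UR50D')
model = AutoModelForMaskedLM.from_pretrained('facebook/esm2_t33_650M_UR50D')

# OMGprot50 protein sequences; the 'sequence' column name is an assumption.
dataset = load_dataset('tattabio/OMG_prot50', split='train')
dataset = dataset.map(
    lambda batch: tokenizer(batch['sequence'], truncation=True, max_length=2048),
    batched=True,
    remove_columns=dataset.column_names,
)

# Traditional MLM objective with 20% masking.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.2)

args = TrainingArguments(
    output_dir='fastesm2_650_cpt',   # assumption
    max_steps=50_000,                # 50,000 additional steps
    fp16=True,                       # fp16 mixed precision
    per_device_train_batch_size=8,   # assumption
    learning_rate=1e-4,              # assumption
)

Trainer(model=model, args=args, train_dataset=dataset, data_collator=collator).train()
```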

  ## Use with 🤗 transformers
+
+ ### For working with embeddings
  ```python
  import torch
+ from transformers import AutoModel, AutoTokenizer

  model_path = 'Synthyra/FastESM2_650'
+ model = AutoModel.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
+ tokenizer = model.tokenizer

  sequences = ['MPRTEIN', 'MSEQWENCE']
  tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
  with torch.no_grad():
+     embeddings = model(**tokenized).last_hidden_state

  print(embeddings.shape) # (2, 11, 1280)
  ```

+ ### For working with sequence logits
+ ```python
+ import torch
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+
+ model_path = 'Synthyra/FastESM2_650'
+ model = AutoModelForMaskedLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).eval()
+ tokenizer = model.tokenizer
+
+ sequences = ['MPRTEIN', 'MSEQWENCE']
+ tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
+ with torch.no_grad():
+     logits = model(**tokenized).logits
+
+ print(logits.shape) # (2, 11, 33)
+ ```

  ## Embed entire datasets with no new code
  To embed a list of protein sequences **fast**, just call embed_dataset. Sequences are sorted to reduce padding tokens, so the initial progress bar estimation is usually much longer than the actual time.

      sql_db_path='embeddings.db', # path to .db file of choice
  )
  ```
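
As a rough illustration of such a call, assuming `model` is the FastESM model loaded as above (every argument name below except `sql_db_path` is an assumption, not the documented signature):

```python
# Hypothetical sketch: embed a list of sequences and write them to the given .db file.
sequences = ['MPRTEIN', 'MSEQWENCE']  # any list of protein sequences

_ = model.embed_dataset(
    sequences=sequences,         # assumed parameter name
    sql_db_path='embeddings.db', # path to .db file of choice
)
```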

+
  ## Model probes
  We employ linear probing techniques on various PLMs and standard datasets, similar to our previous [paper](https://www.biorxiv.org/content/10.1101/2024.07.30.605924v1), to assess the intrinsic correlation between pooled hidden states and valuable properties. FastESM performs very well.
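
As a sketch of this kind of probe (the mean pooling and the scikit-learn logistic-regression probe below are illustrative assumptions, not necessarily the paper's exact protocol):

```python
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from transformers import AutoModel

model = AutoModel.from_pretrained('Synthyra/FastESM2_650', trust_remote_code=True).eval()
tokenizer = model.tokenizer

def mean_pooled_embeddings(seqs):
    """Mean-pool each sequence's last hidden state into one fixed-size vector."""
    feats = []
    with torch.no_grad():
        for seq in seqs:
            tokens = tokenizer(seq, return_tensors='pt')
            hidden = model(**tokens).last_hidden_state  # (1, seq_len, 1280)
            feats.append(hidden.mean(dim=1).squeeze(0).float().numpy())
    return np.stack(feats)

# Placeholder data: any labeled set of protein sequences and a binary property.
train_seqs, train_labels = ['MPRTEIN', 'MSEQWENCE', 'MKTAYIAK', 'MVLSPADK'], [0, 1, 0, 1]
test_seqs, test_labels = ['MPRTEIN', 'MKTAYIAK'], [0, 0]

probe = LogisticRegression(max_iter=1000).fit(mean_pooled_embeddings(train_seqs), train_labels)
print(probe.score(mean_pooled_embeddings(test_seqs), test_labels))
```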