---
license: apache-2.0
---
# InterProt ESM2 SAE Models

A set of SAE models trained on [ESM2-650](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using protein sequences from [UniProt](https://www.uniprot.org/). The [InterProt website](https://interprot.com/) has an interactive visualizer of the SAE features.

## Installation

```bash
pip install git+https://github.com/etowahadams/interprot.git
```

## Usage

### Load SAE
```python
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder

# 1280 = ESM2-650M hidden size, 4096 = number of SAE features
sae_model = SparseAutoencoder(1280, 4096)
checkpoint_path = 'esm2_plm1280_l24_sae4096.safetensors'
sae_model.load_state_dict(load_file(checkpoint_path))
```
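
If you don't have the checkpoint locally, one way to fetch it is via `huggingface_hub`. This is a minimal sketch; the `repo_id` below is an assumption, so substitute the id of this model repository:

```python
from huggingface_hub import hf_hub_download

# Hypothetical repo id; replace with the actual id of this model repository
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors",
)
```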

### ESM -> SAE Inference
```python
import torch
from transformers import AutoTokenizer, EsmModel

# Load ESM model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")

# Run ESM inference with some sequence and take layer 24 activations
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
esm_layer = 24

inputs = tokenizer([seq], padding=True, return_tensors="pt")
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
esm_layer_acts = outputs.hidden_states[esm_layer]  # (1, sequence length + 2, 1280)

# Run SAE inference with ESM activations as input
sae_acts = sae_model.get_acts(esm_layer_acts)
sae_acts  # (1, sequence length + 2, 4096)
```
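
To sanity-check the output, you can rank which SAE features fire most strongly at each position. A minimal sketch building on `sae_acts` from the snippet above (note that positions 0 and -1 correspond to ESM2's BOS/EOS tokens, not residues):

```python
# Top 5 SAE features per position, ranked by activation strength
top_vals, top_feats = sae_acts[0].topk(k=5, dim=-1)  # each (sequence length + 2, 5)

# Residue i of the sequence maps to position i + 1 (position 0 is the BOS token)
print(top_feats[1])  # indices of the most active features on the first residue
print(top_vals[1])   # their activation values
```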