liambai committed
Commit 2a8a50f · verified · 1 Parent(s): 54f5bbf

Update README.md

Files changed (1):
  1. README.md (+27, −18)
README.md CHANGED
@@ -15,35 +15,44 @@ pip install git+https://github.com/etowahadams/interprot.git
 
 ## Usage
 
-Load the SAE
+Install InterProt, load ESM and SAE
 ```python
+import torch
+from transformers import AutoTokenizer, EsmModel
 from safetensors.torch import load_file
 from interprot.sae_model import SparseAutoencoder
+from huggingface_hub import hf_hub_download
 
-sae_model = SparseAutoencoder(1280, 4096)
-checkpoint_path = 'esm2_plm1280_l24_sae4096.safetensors'
-sae_model.load_state_dict(load_file(checkpoint_path))
-```
+ESM_DIM = 1280
+SAE_DIM = 4096
+LAYER = 24
 
-Load ESM and run ESM inference -> SAE inference
-```
-import torch
-from transformers import AutoTokenizer, EsmModel
-
-# Load ESM model and tokenizer
+# Load ESM model
 tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
 esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
 
-# Run ESM inference with some sequence and take layer 24 activations
+# Load SAE model
+checkpoint_path = hf_hub_download(
+    repo_id="liambai/InterProt-ESM2-SAEs",
+    filename="esm2_plm1280_l24_sae4096.safetensors"
+)
+sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
+sae_model.load_state_dict(load_file(checkpoint_path))
+```
+
+ESM -> SAE inference on an amino acid sequence of length `L`
+```
 seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
-esm_layer = 24
 
-inputs = tokenizer([seq], padding=True, return_tensors="pt")
+# Tokenize sequence and run ESM inference
+inputs = tokenizer(seq, padding=True, return_tensors="pt")
 with torch.no_grad():
     outputs = esm_model(**inputs, output_hidden_states=True)
-esm_layer_acts = outputs.hidden_states[esm_layer] # (1, sequence length + 2, 1280)
 
-# Run SAE inference with ESM activations as input
-sae_acts = sae_model.get_acts(esm_layer_acts)
-sae_acts # (1, sequence length + 2, 4096)
+# esm_layer_acts has shape (L+2, ESM_DIM), +2 for BoS and EoS tokens
+esm_layer_acts = outputs.hidden_states[LAYER][0]
+
+# Using ESM embeddings from LAYER, run SAE inference
+sae_acts = sae_model.get_acts(esm_layer_acts) # (L+2, SAE_DIM)
+sae_acts
 ```
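
As a quick follow-up, here is a minimal sketch of how the output of the updated snippet might be inspected. It is not part of the committed README: the BoS/EoS trimming and the per-residue top-latent printout are illustrative assumptions layered on top of the `seq` and `sae_acts` variables defined in the new usage code, where `sae_acts` has shape `(L+2, SAE_DIM)`.

```python
# Illustrative only (not in the commit): assumes `seq` and `sae_acts`
# from the updated README snippet above.
per_residue_acts = sae_acts[1:-1]  # drop BoS/EoS rows -> (L, SAE_DIM)

# For each residue, report the single SAE latent with the largest activation.
top_vals, top_idx = per_residue_acts.max(dim=-1)
for pos, (aa, latent, val) in enumerate(
    zip(seq, top_idx.tolist(), top_vals.tolist()), start=1
):
    print(f"{pos:>3} {aa}  latent {latent:<5d} activation {val:.3f}")
```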