---
license: apache-2.0
---
# InterProt ESM2 SAE Models

A set of sparse autoencoder (SAE) models trained on [ESM2-650M](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation mostly follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a Top-K activation function.

For more information, check out our [preprint](https://www.biorxiv.org/content/10.1101/2025.02.06.636901v1). Our SAEs can be viewed and interacted with on [interprot.com](https://interprot.com).
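
To make the Top-K activation mentioned above concrete, here is a minimal sketch of the idea: keep the `k` largest pre-activations per token and zero out the rest. This is an illustration only, not the InterProt implementation. The helper name `topk_activation`, the toy values, and the ReLU applied to the kept values are assumptions made for this example.

```python
import torch


def topk_activation(pre_acts: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k largest pre-activations in each row and zero out the rest.

    Hypothetical helper for illustration; not part of the interprot package.
    """
    values, indices = torch.topk(pre_acts, k, dim=-1)
    sparse = torch.zeros_like(pre_acts)
    # Assumption of this sketch: clamp kept values at zero so activations are non-negative.
    sparse.scatter_(-1, indices, torch.relu(values))
    return sparse


# Toy pre-activations: 2 tokens, 4 latent features
pre_acts = torch.tensor([[0.5, -1.0, 2.0, 0.1],
                         [3.0, 0.2, -0.4, 1.5]])
sparse_acts = topk_activation(pre_acts, k=2)
print(sparse_acts)
```

With `k=2`, each row keeps exactly two nonzero entries, which is what makes the resulting feature activations sparse.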

## Installation

```bash
pip install git+https://github.com/etowahadams/interprot.git
```

## Usage

Load ESM and the SAE:
```python
import torch
from transformers import AutoTokenizer, EsmModel
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download

ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model.to(device)
esm_model.eval()

# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors"
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()
```

Run ESM -> SAE inference on an amino acid sequence of length `L`:
```python
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"

# Tokenize sequence and run ESM inference
inputs = tokenizer(seq, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)

# esm_layer_acts has shape (L+2, ESM_DIM), +2 for BoS and EoS tokens
esm_layer_acts = outputs.hidden_states[LAYER][0]

# Using ESM embeddings from LAYER, run SAE inference
sae_acts = sae_model.get_acts(esm_layer_acts) # (L+2, SAE_DIM)
sae_acts
```
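
As a possible next step, the sketch below shows one way to inspect the result: drop the BoS and EoS rows so that activations line up with residues, then list the features with the strongest activations. To keep it runnable on its own, it uses a random tensor as a stand-in for the real `sae_acts`. The choice of five features and the variable names are assumptions made for this example.

```python
import torch

SAE_DIM = 4096
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"

# Stand-in for the real `sae_acts` of shape (L+2, SAE_DIM); random values, for illustration only
torch.manual_seed(0)
sae_acts = torch.relu(torch.randn(len(seq) + 2, SAE_DIM))

# Drop the BoS (first) and EoS (last) rows so that row i corresponds to seq[i]
residue_acts = sae_acts[1:-1]  # (L, SAE_DIM)

# Strongest activation of each feature anywhere in the sequence
max_per_feature = residue_acts.max(dim=0).values  # (SAE_DIM,)
top_vals, top_features = torch.topk(max_per_feature, k=5)

for feat, val in zip(top_features.tolist(), top_vals.tolist()):
    pos = int(residue_acts[:, feat].argmax())
    # Report positions 1-indexed, as is conventional for protein sequences
    print(f"feature {feat}: max activation {val:.3f} at residue {seq[pos]}{pos + 1}")
```

With real activations, replace the stand-in tensor with the `sae_acts` computed above.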

## Note on the default checkpoint on [interprot.com](https://interprot.com)

In November 2024, we shared an earlier version of our layer 24 SAE on [X](https://x.com/liambai21/status/1852765669080879108?s=46) and received a lot of amazing community support in identifying SAE features, so we have kept it as the default on [interprot.com](https://interprot.com). Since then, we have retrained the layer 24 SAE with slightly different hyperparameters and on more sequences (1M vs. the original 100K). The new SAE is named `esm2_plm1280_l24_sae4096.safetensors`, whereas the original is named `esm2_plm1280_l24_sae4096_100k.safetensors`.

We recommend using `esm2_plm1280_l24_sae4096.safetensors`, but if you'd like to reproduce the default SAE on [interprot.com](https://interprot.com), you can use `esm2_plm1280_l24_sae4096_100k.safetensors`. The SAEs for all other layers are trained with the same configuration as `esm2_plm1280_l24_sae4096.safetensors`.
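
A small sketch of how the choice between the two layer 24 checkpoints could be made explicit in code. The helper names and the `"recommended"` / `"interprot_default"` labels are hypothetical and introduced only for this example. The repository ID and filenames are the ones given above.

```python
REPO_ID = "liambai/InterProt-ESM2-SAEs"

# Hypothetical labels for the two layer 24 checkpoints described above
LAYER_24_CHECKPOINTS = {
    "recommended": "esm2_plm1280_l24_sae4096.safetensors",  # retrained on 1M sequences
    "interprot_default": "esm2_plm1280_l24_sae4096_100k.safetensors",  # original, 100K sequences
}


def layer24_checkpoint_filename(variant: str = "recommended") -> str:
    """Return the checkpoint filename for the requested layer 24 variant."""
    try:
        return LAYER_24_CHECKPOINTS[variant]
    except KeyError:
        raise ValueError(
            f"Unknown variant {variant!r}; expected one of {sorted(LAYER_24_CHECKPOINTS)}"
        ) from None


def download_layer24_checkpoint(variant: str = "recommended") -> str:
    """Download the chosen checkpoint and return its local path (requires network access)."""
    # Imported here so the filename helper above works without huggingface_hub installed
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=REPO_ID, filename=layer24_checkpoint_filename(variant))
```

For example, `download_layer24_checkpoint("interprot_default")` would fetch the original 100K checkpoint, and its returned path can be passed to `load_file` as in the usage section.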