---
license: apache-2.0
---
# InterProt ESM2 SAE Models
A set of SAE models trained on [ESM2-650](https://huggingface.co/facebook/esm2_t33_650M_UR50D) activations using 1M protein sequences from [UniProt](https://www.uniprot.org/). The SAE implementation largely follows [Gao et al.](https://arxiv.org/abs/2406.04093) and uses a TopK activation function.
For more information, check out our [preprint](https://www.biorxiv.org/content/10.1101/2025.02.06.636901v1). Our SAEs can be viewed and interacted with on [interprot.com](https://interprot.com).
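The TopK activation keeps only the `k` largest latent activations per token and zeroes out the rest. Below is a minimal, illustrative sketch of such a forward pass; it is **not** the `interprot` implementation, and the layer names and `k` value are assumptions (only the 1280/4096 dimensions match the models in this repo).

```python
import torch
import torch.nn as nn

class TopKSAESketch(nn.Module):
    """Illustrative TopK sparse autoencoder (hypothetical, for explanation only)."""

    def __init__(self, d_model: int = 1280, d_hidden: int = 4096, k: int = 32):
        super().__init__()
        self.k = k
        self.encoder = nn.Linear(d_model, d_hidden)
        self.decoder = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encode, keep only the top-k latents per token, zero the rest, then decode
        acts = self.encoder(x)
        topk = torch.topk(acts, self.k, dim=-1)
        sparse = torch.zeros_like(acts).scatter_(-1, topk.indices, topk.values)
        return self.decoder(sparse)
```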
## Installation
```bash
pip install git+https://github.com/etowahadams/interprot.git
```
## Usage
Install InterProt, then load the ESM and SAE models:
```python
import torch
from transformers import AutoTokenizer, EsmModel
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download
ESM_DIM = 1280
SAE_DIM = 4096
LAYER = 24
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load ESM model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model.to(device)
esm_model.eval()
# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors"
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()
```
ESM -> SAE inference on an amino acid sequence of length `L`:
```python
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"
# Tokenize sequence and run ESM inference
inputs = tokenizer(seq, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)
# esm_layer_acts has shape (L+2, ESM_DIM), +2 for BoS and EoS tokens
esm_layer_acts = outputs.hidden_states[LAYER][0]
# Using ESM embeddings from LAYER, run SAE inference
sae_acts = sae_model.get_acts(esm_layer_acts) # (L+2, SAE_DIM)
sae_acts
```
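For example, to align the activations with the amino acid sequence and inspect the strongest latents at a given residue, you can drop the BoS/EoS positions and use `torch.topk`. This is an illustrative snippet that assumes `sae_acts` from the block above:

```python
# Drop BoS/EoS tokens so positions align with the amino acid sequence
per_residue_acts = sae_acts[1:-1]  # (L, SAE_DIM)

# Top 5 most strongly activating SAE latents at the first residue
values, indices = per_residue_acts[0].topk(5)
print(indices.tolist(), values.tolist())
```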
## Note on the default checkpoint on [interprot.com](https://interprot.com)
In November 2024, we shared an earlier version of our layer 24 SAE on [X](https://x.com/liambai21/status/1852765669080879108?s=46) and received a lot of amazing community support in identifying SAE features; we have therefore kept it as the default on [interprot.com](https://interprot.com). Since then, we retrained the layer 24 SAE with slightly different hyperparameters and on more sequences (1M vs. the original 100K). The new SAE is named `esm2_plm1280_l24_sae4096.safetensors`, whereas the original is named `esm2_plm1280_l24_sae4096_100k.safetensors`.
We recommend using `esm2_plm1280_l24_sae4096.safetensors`, but if you'd like to reproduce the default SAE on [interprot.com](https://interprot.com), you can use `esm2_plm1280_l24_sae4096_100k.safetensors`. All other layer SAEs were trained with the same configuration as `esm2_plm1280_l24_sae4096.safetensors`.
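To load the 100K-sequence checkpoint instead, swap the filename in the `hf_hub_download` call from the usage example above; the rest of the loading code is unchanged:

```python
# Download the original 100K-sequence layer 24 checkpoint instead of the default
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096_100k.safetensors"
)
```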