Update README.md
Browse files
README.md
CHANGED
@@ -27,9 +27,13 @@ ESM_DIM = 1280
|
|
27 |
SAE_DIM = 4096
|
28 |
LAYER = 24
|
29 |
|
|
|
|
|
30 |
# Load ESM model
|
31 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
32 |
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
|
|
|
|
33 |
|
34 |
# Load SAE model
|
35 |
checkpoint_path = hf_hub_download(
|
@@ -38,11 +42,13 @@ checkpoint_path = hf_hub_download(
|
|
38 |
)
|
39 |
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
|
40 |
sae_model.load_state_dict(load_file(checkpoint_path))
|
|
|
|
|
41 |
```
|
42 |
|
43 |
ESM -> SAE inference on an amino acid sequence of length `L`
|
44 |
```
|
45 |
-
seq = "
|
46 |
|
47 |
# Tokenize sequence and run ESM inference
|
48 |
inputs = tokenizer(seq, padding=True, return_tensors="pt")
|
|
|
27 |
SAE_DIM = 4096
|
28 |
LAYER = 24
|
29 |
|
30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
|
32 |
# Load ESM model
|
33 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
34 |
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
35 |
+
esm_model.to(device)
|
36 |
+
esm_model.eval()
|
37 |
|
38 |
# Load SAE model
|
39 |
checkpoint_path = hf_hub_download(
|
|
|
42 |
)
|
43 |
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
|
44 |
sae_model.load_state_dict(load_file(checkpoint_path))
|
45 |
+
sae_model.to(device)
|
46 |
+
sae_model.eval()
|
47 |
```
|
48 |
|
49 |
ESM -> SAE inference on an amino acid sequence of length `L`
|
50 |
```
|
51 |
+
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"
|
52 |
|
53 |
# Tokenize sequence and run ESM inference
|
54 |
inputs = tokenizer(seq, padding=True, return_tensors="pt")
|