liambai committed
Commit b91dcc7 · verified · 1 Parent(s): 821ad1e

Update README.md

Files changed (1):
  1. README.md +7 -1
README.md CHANGED
@@ -27,9 +27,13 @@ ESM_DIM = 1280
 SAE_DIM = 4096
 LAYER = 24

+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 # Load ESM model
 tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
 esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
+esm_model.to(device)
+esm_model.eval()

 # Load SAE model
 checkpoint_path = hf_hub_download(
@@ -38,11 +42,13 @@ checkpoint_path = hf_hub_download(
 )
 sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
 sae_model.load_state_dict(load_file(checkpoint_path))
+sae_model.to(device)
+sae_model.eval()
 ```

 ESM -> SAE inference on an amino acid sequence of length `L`
 ```
-seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVVAAIVQDIAYLRSLGYNIVATPRGYVLAGG"
+seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"

 # Tokenize sequence and run ESM inference
 inputs = tokenizer(seq, padding=True, return_tensors="pt")
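
The hunk ends mid-snippet at the tokenizer call. For context, a minimal sketch of how the inference might continue, assuming the layer-24 residue embeddings are passed to the SAE through a plain forward call on a `(L, ESM_DIM)` tensor; the repo's actual `SparseAutoencoder` interface may differ:

```
import torch

with torch.no_grad():
    # Run ESM and keep all hidden states so the target layer can be selected
    inputs = inputs.to(device)
    outputs = esm_model(**inputs, output_hidden_states=True)

    # Layer-24 residue embeddings; drop the BOS/EOS tokens to keep length L
    esm_acts = outputs.hidden_states[LAYER][0, 1:-1]  # (L, ESM_DIM)

    # Assumed call: the SAE's forward pass returning per-residue latent activations
    sae_acts = sae_model(esm_acts)  # (L, SAE_DIM) under this assumption
```

Moving both models to `device` and calling `.eval()` beforehand, as this commit adds to the README, keeps dropout disabled and avoids CPU/GPU tensor mismatches in a block like the one above.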