zhengxuanzenwu commited on
Commit
5c9603a
·
verified ·
1 Parent(s): d52a5d4

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -1
README.md CHANGED
@@ -20,8 +20,15 @@ It is a single dictionary of subspaces for 16K concepts and serves as a drop-in
20
 
21
  ```python
22
  from huggingface_hub import hf_hub_download
 
 
23
  import pyvene as pv
24
 
 
 
 
 
 
25
  # Create an intervention.
26
  class Encoder(pv.CollectIntervention):
27
  """An intervention that reads concept latent from streams"""
@@ -33,7 +40,8 @@ class Encoder(pv.CollectIntervention):
33
  return torch.relu(self.proj(base))
34
 
35
  # Loading weights
36
- path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
 
37
  encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1])
38
  encoder.proj.weight.data = params.float()
39
 
@@ -48,6 +56,7 @@ input_ids = torch.tensor([tokenizer.apply_chat_template(
48
  [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True)]).cuda()
49
  acts = pv_model.forward(
50
  {"input_ids": input_ids}, return_dict=True).collected_activations[0]
 
51
  ```
52
 
53
  # 4. Point of Contact
 
20
 
21
  ```python
22
  from huggingface_hub import hf_hub_download
23
+ from transformers import AutoModelForCausalLM, AutoTokenizer
24
+ import torch
25
  import pyvene as pv
26
 
27
+ # Load model and tokenizer
28
+ model_name = "google/gemma-2-2b-it"
29
+ model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
30
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
31
+
32
  # Create an intervention.
33
  class Encoder(pv.CollectIntervention):
34
  """An intervention that reads concept latent from streams"""
 
40
  return torch.relu(self.proj(base))
41
 
42
  # Loading weights
43
+ path_to_params = hf_hub_download(repo_id="pyvene/gemma-diffmean-9b-it-res", filename="l20/weight.pt")
44
+ params = torch.load(path_to_params)
45
  encoder = Encoder(embed_dim=params.shape[0], latent_dim=params.shape[1])
46
  encoder.proj.weight.data = params.float()
47
 
 
56
  [{"role": "user", "content": prompt}], tokenize=True, add_generation_prompt=True)]).cuda()
57
  acts = pv_model.forward(
58
  {"input_ids": input_ids}, return_dict=True).collected_activations[0]
59
+
60
  ```
61
 
62
  # 4. Point of Contact