frankaging committed
Commit f860e61
1 Parent(s): 7065c79
Files changed (1)
  1. app.py +192 -69
app.py CHANGED
@@ -1,37 +1,32 @@
- # login as a privileged user.
  import os, json
- HF_TOKEN = os.environ.get("HF_TOKEN")
-
- from huggingface_hub import login, hf_hub_download
- login(token=HF_TOKEN)
-
- from threading import Thread
- from typing import Iterator
-
  import gradio as gr
  import spaces
- import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import pyvene as pv

  MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

-
  DESCRIPTION = """\
  # Model Steering with Supervised Dictionary Learning (SDL)

  ### What's Model Steering with SDL?
- This is a demo of model steering with Supervised Dictionary Learning (SDL) using AxBench-ReFT-r1-16K, which hosts steering vectors for 16K concepts. We evaluate various steering methods, including ReFT-r1, a novel weakly-supervised dictionary learning method. ReFT-r1 demonstrates competitive steering capabilities compared to finetuning and prompting baselines.
  """

  LICENSE = """
  <p/>

  ---
- This demo is governed by the original license and acceptable use policy of the model it is derived from. Please refer to the specific licensing and use policy of the underlying model.
  """

  def load_jsonl(jsonl_path):
@@ -39,81 +34,112 @@ def load_jsonl(jsonl_path):
      with open(jsonl_path, 'r') as f:
          for line in f:
              data = json.loads(line)
-             jsonl_data += [data]
      return jsonl_data

-
  class Steer(pv.SourcelessIntervention):
      """Steer model via activation addition"""
      def __init__(self, **kwargs):
          super().__init__(**kwargs, keep_last_dim=True)
-         self.proj = torch.nn.Linear(
-             self.embed_dim, kwargs["latent_dim"], bias=False)
-     def forward(self, base, source=None, subspaces=None):
-         steering_vec = torch.tensor(subspaces["mag"]) * \
-             self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
-         return base + steering_vec

  if not torch.cuda.is_available():
-     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-

  if torch.cuda.is_available():
-     # load the LLM
      model_id = "google/gemma-2-2b-it"
      model = AutoModelForCausalLM.from_pretrained(
          model_id, device_map="cuda", torch_dtype=torch.bfloat16
      )
      tokenizer = AutoTokenizer.from_pretrained(model_id)

-     # load the dictionary
-     path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt", force_download=False)
-     path_to_md = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl", force_download=False)
      params = torch.load(path_to_params).cuda()
      md = load_jsonl(path_to_md)
-     id_to_concept = {item["concept_id"]: item["concept"] for item in md}
      concept_list = [item["concept"] for item in md]

      steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
      steer.proj.weight.data = params.float()

-     # Mount the encoder to the model
-     pv_model = pv.IntervenableModel({
-         "component": f"model.layers[20].output",
-         "intervention": steer}, model=model)

-     terminators = [
-         tokenizer.eos_token_id,
-     ]

  @spaces.GPU
  def generate(
      message: str,
      chat_history: list[tuple[str, str]],
-     max_new_tokens: int = 1024,
  ) -> Iterator[str]:

-     # tokenize and prepare the input
-     prompt = torch.tensor([tokenizer.apply_chat_template(
-         [{"role": "user", "content": message}], tokenize=True, add_generation_prompt=True)]).cuda()
-
-     input_ids = prompt["input_ids"]
-     attention_mask = prompt["attention_mask"]
-
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
-         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-
      streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = {
          "base": {"input_ids": input_ids, "attention_mask": attention_mask},
          "unit_locations": None,
          "max_new_tokens": max_new_tokens,
          "intervene_on_prompt": True,
-         "subspaces": [{"idx": 1795, "mag": 150.0}],
          "streamer": streamer,
          "eos_token_id": terminators,
          "early_stopping": True,
@@ -123,33 +149,130 @@ def generate(
      t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
      t.start()

-     outputs = []
-     for text in streamer:
-         outputs.append(text)
-         yield "".join(outputs)
-
-
- chat_interface = gr.ChatInterface(
-     fn=generate,
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         )
-     ],
-     stop_btn=None,
-     title="Model Steering with ReFT-r1 (16K concepts)",
- )

  with gr.Blocks(css="style.css") as demo:
      gr.Markdown(DESCRIPTION)
      gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
-     chat_interface.render()
      gr.Markdown(LICENSE)

- if __name__ == "__main__":
-     demo.queue(max_size=20).launch()

  import os, json
+ import torch
  import gradio as gr
  import spaces
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from huggingface_hub import login, hf_hub_download
  import pyvene as pv
+ from threading import Thread
+ from typing import Iterator

+ HF_TOKEN = os.environ.get("HF_TOKEN")
+ login(token=HF_TOKEN)

  MAX_MAX_NEW_TOKENS = 2048
  DEFAULT_MAX_NEW_TOKENS = 1024
  MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

  DESCRIPTION = """\
  # Model Steering with Supervised Dictionary Learning (SDL)

  ### What's Model Steering with SDL?
+ This is a demo of model steering with AxBench-ReFT-r1-16K, ...
  """
  LICENSE = """
  <p/>

  ---
+ Please refer to the specific licensing and use policy of the underlying model.
  """

  def load_jsonl(jsonl_path):
      jsonl_data = []
      with open(jsonl_path, 'r') as f:
          for line in f:
              data = json.loads(line)
+             jsonl_data.append(data)
      return jsonl_data

  class Steer(pv.SourcelessIntervention):
      """Steer model via activation addition"""
      def __init__(self, **kwargs):
          super().__init__(**kwargs, keep_last_dim=True)
+         self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)

+     def forward(self, base, source=None, subspaces=None):
+         # subspaces is a list of dicts: each has {"idx": int, "mag": float}
+         steer_vec = base
+         if subspaces is not None:
+             for sp in subspaces:
+                 idx = sp["idx"]
+                 mag = sp["mag"]
+                 # each idx selects a row of self.proj.weight
+                 steering_vec = mag * self.proj.weight[idx].unsqueeze(dim=0)
+                 steer_vec = steer_vec + steering_vec
+         return steer_vec
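+ # For example, two hypothetical active concepts {"idx": i, "mag": 150.0} and
+ # {"idx": j, "mag": -80.0} compose additively into
+ #     base + 150.0 * proj.weight[i] - 80.0 * proj.weight[j]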

+ # ---------------------------------------------------
+ # Load model & dictionary if a GPU is available
+ # ---------------------------------------------------
  if not torch.cuda.is_available():
+     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo won't perform well on CPU.</p>"

  if torch.cuda.is_available():
      model_id = "google/gemma-2-2b-it"
      model = AutoModelForCausalLM.from_pretrained(
          model_id, device_map="cuda", torch_dtype=torch.bfloat16
      )
      tokenizer = AutoTokenizer.from_pretrained(model_id)

+     path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
+     path_to_md = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
      params = torch.load(path_to_params).cuda()
      md = load_jsonl(path_to_md)
+
      concept_list = [item["concept"] for item in md]
+     concept_id_map = {item["concept"]: item["concept_id"] for item in md}

      steer = Steer(embed_dim=params.shape[0], latent_dim=params.shape[1])
      steer.proj.weight.data = params.float()

+     pv_model = pv.IntervenableModel(
+         {
+             "component": "model.layers[20].output",
+             "intervention": steer,
+         },
+         model=model,
+     )
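+     # The intervention is mounted on the output of decoder layer 20; the
+     # l20/ dictionary files downloaded above are keyed to this layer.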
 
+     terminators = [tokenizer.eos_token_id]


+ # ---------------------------------------------------------------------
+ # Main generation function: keep only the last 3 conversation turns,
+ # then build the prompt with apply_chat_template
+ # ---------------------------------------------------------------------
  @spaces.GPU
  def generate(
      message: str,
      chat_history: list[tuple[str, str]],
+     max_new_tokens: int,
+     subspaces_list: list[dict],
  ) -> Iterator[str]:

+     # Restrict to the last 3 turns only
+     start_idx = max(0, len(chat_history) - 3)
+     recent_history = chat_history[start_idx:]
+
+     # Build a list of messages;
+     # each history tuple is (user_message, assistant_message)
+     messages = []
+     for user_msg, assistant_msg in recent_history:
+         messages.append({"role": "user", "content": user_msg})
+         messages.append({"role": "assistant", "content": assistant_msg})
+
+     # Now append the new user message
+     messages.append({"role": "user", "content": message})
+
+     # Convert messages into model input tokens with a generation prompt
+     prompt = tokenizer.apply_chat_template(
+         messages,
+         tokenize=True,
+         add_generation_prompt=True,  # appends the assistant-turn header for the model to continue
+         return_dict=True,  # return input_ids and attention_mask rather than a bare id list
+     )
+
+     # Retrieve input_ids and mask
+     input_ids = torch.tensor([prompt["input_ids"]]).cuda()
+     attention_mask = torch.tensor([prompt["attention_mask"]]).cuda()
+
+     # Possibly trim if over max length
      if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
          input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
          attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
+         yield "\n[Warning: Truncated conversation exceeds max allowed input tokens]\n"
+
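+     # Note that left-truncation keeps the most recent MAX_INPUT_TOKEN_LENGTH
+     # tokens, so the oldest context is dropped first.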
      streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
      generate_kwargs = {
          "base": {"input_ids": input_ids, "attention_mask": attention_mask},
          "unit_locations": None,
          "max_new_tokens": max_new_tokens,
          "intervene_on_prompt": True,
+         "subspaces": subspaces_list,
          "streamer": streamer,
          "eos_token_id": terminators,
          "early_stopping": True,
      }

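+     # pyvene passes "subspaces" through pv_model.generate to the mounted
+     # Steer intervention's forward() on each decoding step.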
      t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
      t.start()

+     partial_text = []
+     for token_str in streamer:
+         partial_text.append(token_str)
+         yield "".join(partial_text)

+ # --------------
+ # UI Callbacks
+ # --------------
+ def filter_concepts(search_text: str):
+     if not search_text.strip():
+         return concept_list[:500]
+     filtered = [c for c in concept_list if search_text.lower() in c.lower()]
+     return filtered[:500]
+
+ def add_concept_to_list(selected_concept, magnitude, current_list):
+     """When 'Add Concept' is clicked, add the chosen concept and magnitude to subspaces."""
+     if not selected_concept:
+         table_data = [[x["idx"], x["mag"]] for x in current_list]
+         return current_list, table_data, gr.update(choices=[str(x["idx"]) for x in current_list])
+     concept_idx = concept_id_map[selected_concept]
+     new_entry = {"idx": concept_idx, "mag": magnitude}
+     updated_list = current_list + [new_entry]

+     remove_choices = [str(x["idx"]) for x in updated_list]
+     table_data = [[x["idx"], x["mag"]] for x in updated_list]
+     return updated_list, table_data, gr.update(choices=remove_choices)
+
+ def remove_concept_from_list(rem_concept_idx_str, current_list):
+     """Remove the chosen concept from the list. Index is a string from remove_dropdown."""
+     if not rem_concept_idx_str:
+         table_data = [[x["idx"], x["mag"]] for x in current_list]
+         return current_list, table_data, gr.update()
+     rem_idx = int(rem_concept_idx_str)
+     updated_list = [x for x in current_list if x["idx"] != rem_idx]
+     remove_choices = [str(x["idx"]) for x in updated_list]
+     table_data = [[x["idx"], x["mag"]] for x in updated_list]
+     return updated_list, table_data, gr.update(choices=remove_choices)
+
+ def update_dropdown_choices(search_text):
+     filtered = filter_concepts(search_text)
+     return gr.update(choices=filtered)


+ # -------------------------
+ # Build the Gradio Blocks
+ # -------------------------
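+ # The gr.State created below holds the active [{"idx": ..., "mag": ...}] list
+ # that the callbacks above edit and generate() consumes.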
  with gr.Blocks(css="style.css") as demo:
      gr.Markdown(DESCRIPTION)
      gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+
+     selected_subspaces = gr.State([])
+
+     with gr.Row():
+         with gr.Column():
+             # Searching / selecting a concept
+             search_box = gr.Textbox(
+                 label="Search concepts",
+                 placeholder="Type text to filter concepts (e.g. 'sports')"
+             )
+             concept_dropdown = gr.Dropdown(
+                 label="Filtered Concepts",
+                 choices=[],
+                 multiselect=False
+             )
+             concept_magnitude = gr.Slider(
+                 label="Magnitude",
+                 minimum=-300.0,
+                 maximum=300.0,
+                 step=1.0,
+                 value=150.0
+             )
+             add_button = gr.Button("Add Concept")
+
+             # Removal
+             remove_dropdown = gr.Dropdown(
+                 label="Remove from active list",
+                 choices=[],
+                 multiselect=False
+             )
+             remove_button = gr.Button("Remove Selected")
+
+         with gr.Column():
+             # Display currently active subspaces
+             active_subspaces_table = gr.Dataframe(
+                 headers=["idx", "magnitude"],
+                 datatype=["number", "number"],
+                 interactive=False,
+                 label="Active Concept Subspaces"
+             )
+
+     # The Chat Interface
+     chat_interface = gr.ChatInterface(
+         fn=generate,
+         additional_inputs=[
+             gr.Slider(
+                 label="Max new tokens",
+                 minimum=1,
+                 maximum=MAX_MAX_NEW_TOKENS,
+                 step=1,
+                 value=DEFAULT_MAX_NEW_TOKENS,
+             ),
+             selected_subspaces
+         ],
+         title="Model Steering with ReFT-r1 (16K concepts)",
+     )
+
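+     # Passing the gr.State via additional_inputs hands its current value
+     # (the active subspace list) to generate() on each chat turn.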
      gr.Markdown(LICENSE)

+     # Wire up events
+     search_box.change(
+         fn=update_dropdown_choices,
+         inputs=[search_box],
+         outputs=[concept_dropdown]
+     )
+
+     add_button.click(
+         fn=add_concept_to_list,
+         inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
+         outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
+     )

+     remove_button.click(
+         fn=remove_concept_from_list,
+         inputs=[remove_dropdown, selected_subspaces],
+         outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
+     )
+
+ demo.queue(max_size=20).launch()