frankaging committed
Commit e39562b · 1 Parent(s): 1baa5c3
Files changed (1)
  1. app.py +84 -135
app.py CHANGED
@@ -12,22 +12,8 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
 
 MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-DESCRIPTION = """\
-# Model Steering with Supervised Dictionary Learning (SDL)
-
-### What's Model Steering with SDL?
-This is a demo of model steering with AxBench-ReFT-r1-16K, ...
-"""
-
-LICENSE = """
-<p/>
-
----
-Please refer to the specific licensing and use policy of the underlying model.
-"""
+DEFAULT_MAX_NEW_TOKENS = 512  # smaller default to save memory
+MAX_INPUT_TOKEN_LENGTH = 4096
 
 def load_jsonl(jsonl_path):
     jsonl_data = []
@@ -38,41 +24,41 @@ def load_jsonl(jsonl_path):
     return jsonl_data
 
 class Steer(pv.SourcelessIntervention):
-    """Steer model via activation addition"""
     def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)
 
     def forward(self, base, source=None, subspaces=None):
-        # subspaces is a list of dicts:
-        # each has {"idx": int, "internal_mag": float, "text": str, ...}
         steer_vec = base
         if subspaces is not None:
             for sp in subspaces:
                 idx = sp["idx"]
-                mag = sp["internal_mag"]  # the true scaling factor
+                mag = sp["internal_mag"]  # scaled by 50
                 steering_vec = mag * self.proj.weight[idx].unsqueeze(dim=0)
                 steer_vec = steer_vec + steering_vec
         return steer_vec
 
-
-# ------------------------------------------
-# Load the Model & Dictionary if GPU exists
-# ------------------------------------------
+# Check GPU
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo won't perform well on CPU.</p>"
-
+    print("Warning: Running on CPU, may be slow.")
+
+# Load model & dictionary
+model_id = "google/gemma-2-2b-it"
+pv_model = None
+tokenizer = None
+concept_list = []
+concept_id_map = {}
 if torch.cuda.is_available():
-    model_id = "google/gemma-2-2b-it"
     model = AutoModelForCausalLM.from_pretrained(
         model_id, device_map="cuda", torch_dtype=torch.bfloat16
     )
     tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-    path_to_params = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
-    path_to_md = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
-    params = torch.load(path_to_params).cuda()
-    md = load_jsonl(path_to_md)
+    # Download dictionary
+    weight_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
+    meta_path = hf_hub_download(repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
+    params = torch.load(weight_path).cuda()
+    md = load_jsonl(meta_path)
 
     concept_list = [item["concept"] for item in md]
     concept_id_map = {item["concept"]: item["concept_id"] for item in md}
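
For reference, the activation addition in Steer.forward above reduces to adding one scaled row of the learned dictionary onto the hooked layer's activations. A minimal self-contained sketch of the same arithmetic, with toy dimensions and no pyvene or model download (all names and sizes here are illustrative):

import torch

embed_dim, latent_dim = 8, 16                       # toy sizes; the demo uses the model hidden size and 16K concepts
proj = torch.nn.Linear(embed_dim, latent_dim, bias=False)

base = torch.randn(1, 4, embed_dim)                 # [batch, seq, hidden] activations at the hooked layer
idx, mag = 3, 150.0                                 # concept row and internal magnitude (slider value * 50)

steering_vec = mag * proj.weight[idx].unsqueeze(0)  # [1, embed_dim], broadcast over the sequence
steered = base + steering_vec                       # same shape as base
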
@@ -88,12 +74,8 @@ if torch.cuda.is_available():
         model=model,
     )
 
-    terminators = [tokenizer.eos_token_id]
-
+terminators = [tokenizer.eos_token_id] if tokenizer else []
 
-# --------------------------------------------
-# Main generation function: keep last 3 turns
-# --------------------------------------------
 @spaces.GPU
 def generate(
     message: str,
@@ -101,37 +83,28 @@ def generate(
     max_new_tokens: int,
     subspaces_list: list[dict],
 ) -> Iterator[str]:
-
-    # Restrict to the last 3 turns only
+    # limit to last 3 turns
     start_idx = max(0, len(chat_history) - 3)
     recent_history = chat_history[start_idx:]
 
-    # Convert (user_msg, model_msg) => list of messages
+    # build list of messages
     messages = []
     for user_msg, model_msg in recent_history:
         messages.append({"role": "user", "content": user_msg})
         messages.append({"role": "model", "content": model_msg})
-
-    # Add the new user message
     messages.append({"role": "user", "content": message})
 
-    # Apply the chat template (some HF models expect "assistant" instead of "model")
-    # but let's keep "model" to match your code, if that is required.
-    prompt_dict = tokenizer.apply_chat_template(
-        messages, tokenize=True, add_generation_prompt=True
-    )
-    input_ids = torch.tensor([prompt_dict["input_ids"]]).cuda()
-    attention_mask = torch.tensor([prompt_dict["attention_mask"]]).cuda()
+    input_ids = torch.tensor([tokenizer.apply_chat_template(
+        messages, tokenize=True, add_generation_prompt=True)]).cuda()
 
-    # Possibly trim if too long
+    # trim if needed
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
-        yield "\n[Warning: Truncated conversation exceeds max allowed input tokens]\n"
+        yield "[Truncated prior text]\n"
 
     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = {
-        "base": {"input_ids": input_ids, "attention_mask": attention_mask},
+        "base": {"input_ids": input_ids},
         "unit_locations": None,
         "max_new_tokens": max_new_tokens,
         "intervene_on_prompt": True,
@@ -150,29 +123,20 @@ def generate(
         partial_text.append(token_str)
         yield "".join(partial_text)
 
-
-# ----------------
-# UI Callbacks
-# ----------------
 def filter_concepts(search_text: str):
-    """Return the first 500 concepts that match (case-insensitive)."""
     if not search_text.strip():
         return concept_list[:500]
     filtered = [c for c in concept_list if search_text.lower() in c.lower()]
     return filtered[:500]
 
 def add_concept_to_list(selected_concept, user_slider_val, current_list):
-    """
-    user_slider_val is from [-5..5]. We multiply by 50 internally to get the real magnitude.
-    """
     if not selected_concept:
         return current_list, _build_table_data(current_list), gr.update(choices=_build_remove_choices(current_list))
-
-    concept_idx = concept_id_map[selected_concept]
-    internal_mag = user_slider_val * 50  # scale by 50
+    idx = concept_id_map[selected_concept]
+    internal_mag = user_slider_val * 50
     new_entry = {
         "text": selected_concept,
-        "idx": concept_idx,
+        "idx": idx,
         "display_mag": user_slider_val,
         "internal_mag": internal_mag,
     }
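
The slider exposes a small [-5, 5] range while the intervention expects a larger magnitude; the ×50 factor in add_concept_to_list maps between the two, so the display value 3 corresponds to the stored internal_mag of 150.0. A quick sanity check of that mapping (the helper name is illustrative):

def to_internal(display_mag: int) -> float:
    # slider value in [-5, 5] -> steering magnitude in [-250, 250]
    return display_mag * 50.0

assert to_internal(3) == 150.0
assert to_internal(-5) == -250.0
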
@@ -183,14 +147,10 @@ def add_concept_to_list(selected_concept, user_slider_val, current_list):
         gr.update(choices=_build_remove_choices(updated_list))
     )
 
-def remove_concept_from_list(concept_to_remove, current_list):
-    """
-    Remove the chosen concept name from the list.
-    """
-    if not concept_to_remove:
+def remove_concept_from_list(selected_text, current_list):
+    if not selected_text:
         return current_list, _build_table_data(current_list), gr.update(choices=_build_remove_choices(current_list))
-
-    updated_list = [x for x in current_list if x["text"] != concept_to_remove]
+    updated_list = [x for x in current_list if x["text"] != selected_text]
     return (
         updated_list,
         _build_table_data(updated_list),
@@ -198,115 +158,104 @@ def remove_concept_from_list(concept_to_remove, current_list):
     )
 
 def _build_table_data(subspaces):
-    """Return [[concept_name, scaled_mag], ...] for display."""
     return [[x["text"], x["display_mag"]] for x in subspaces]
 
 def _build_remove_choices(subspaces):
-    """Return concept names for the remove dropdown."""
     return [x["text"] for x in subspaces]
 
 def update_dropdown_choices(search_text):
     filtered = filter_concepts(search_text)
     return gr.update(choices=filtered)
 
-# --------------------------------------------------------------------
-# Build the Interface
-# --------------------------------------------------------------------
 with gr.Blocks(css="style.css") as demo:
-    gr.Markdown(DESCRIPTION)
-    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    # A short title only
+    gr.Markdown("## Model Steering with ReFT-r1 (16K concepts)")
 
-    # If GPU is available, pick a random concept as default
+    # Pre-populate with a random concept if available
     default_subspaces = []
-    if torch.cuda.is_available() and len(concept_list) > 0:
+    if pv_model and concept_list:
         default_concept = random.choice(concept_list)
         default_subspaces = [{
             "text": default_concept,
             "idx": concept_id_map[default_concept],
-            "display_mag": 3,       # user sees 3
-            "internal_mag": 150.0,  # actual factor
+            "display_mag": 3,
+            "internal_mag": 150.0,
         }]
 
     selected_subspaces = gr.State(default_subspaces)
-
+
+    # gr.ChatInterface only accepts extra inputs at construction time, so the
+    # max-tokens slider is created unrendered here, passed via additional_inputs,
+    # and shown in the accordion below the chat.
+    max_token_slider = gr.Slider(
+        minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1,
+        value=DEFAULT_MAX_NEW_TOKENS,
+        label="Max new tokens",
+        render=False,
+    )
     with gr.Row():
-        with gr.Column(scale=5):
-            # Use type="messages" to avoid tuple-format deprecation warnings
+        # Left side: bigger chat area
+        with gr.Column(scale=7):
             chat_interface = gr.ChatInterface(
                 fn=generate,
-                additional_inputs=[
-                    gr.Slider(
-                        label="Max new tokens",
-                        minimum=1,
-                        maximum=MAX_MAX_NEW_TOKENS,
-                        step=1,
-                        value=DEFAULT_MAX_NEW_TOKENS,
-                    ),
-                    selected_subspaces
-                ],
-                title="Model Steering with ReFT-r1 (16K concepts)",
-                type="messages",  # <--- uses openai-style 'role' and 'content'
+                additional_inputs=[max_token_slider, selected_subspaces],
+                title="",
+                type="messages",
+                chatbot=gr.Chatbot(type="messages", height=550),  # taller chat area
             )
-        with gr.Column(scale=4):
-            gr.Markdown("## Steering Concepts")
-
+        # Right side: concept management
+        with gr.Column(scale=3):
+            gr.Markdown("### Steering Concepts")
             search_box = gr.Textbox(
                 label="Search concepts",
-                placeholder="Type text to filter concepts (e.g. 'sports')"
+                placeholder="e.g. 'time travel'"
             )
             concept_dropdown = gr.Dropdown(
                 label="Filtered Concepts",
-                choices=[],  # dynamically populated
-                multiselect=False
+                choices=[]
             )
             concept_magnitude = gr.Slider(
-                label="Scaled Magnitude (×50 internally)",
+                label="Scaled Factor",
                 minimum=-5,
                 maximum=5,
                 step=1,
                 value=3
             )
             add_button = gr.Button("Add Concept")
-
-            # Show the table of active subspaces
             active_subspaces_table = gr.Dataframe(
-                headers=["Concept", "Magnitude (scaled)"],
+                headers=["Concept", "Mag (scaled)"],
                 datatype=["str", "number"],
                 value=_build_table_data(default_subspaces),
                 interactive=False,
-                label="Active Concept Subspaces"
+                label="Active Concept Subspaces",
+                height=170  # give it a bit more room
             )
-
-            # Remove concept by name
-            remove_dropdown = gr.Dropdown(
-                label="Remove a concept",
-                choices=_build_remove_choices(default_subspaces),
-                multiselect=False
-            )
-            remove_button = gr.Button("Remove Selected Concept")
-
-    gr.Markdown(LICENSE)
+            # Row with the remove dropdown + button
+            with gr.Row():
+                remove_dropdown = gr.Dropdown(
+                    label="Remove concept",
+                    choices=_build_remove_choices(default_subspaces),
+                    multiselect=False
+                )
+                remove_button = gr.Button("Remove", variant="secondary")
 
     # Wire up events
-    # Update concept dropdown when user types in search
-    search_box.change(
-        fn=update_dropdown_choices,
-        inputs=[search_box],
-        outputs=[concept_dropdown]
-    )
-
-    # Add concept
+    search_box.change(update_dropdown_choices, [search_box], [concept_dropdown])
     add_button.click(
-        fn=add_concept_to_list,
-        inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
-        outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
+        add_concept_to_list,
+        [concept_dropdown, concept_magnitude, selected_subspaces],
+        [selected_subspaces, active_subspaces_table, remove_dropdown]
     )
-
-    # Remove a concept
     remove_button.click(
-        fn=remove_concept_from_list,
-        inputs=[remove_dropdown, selected_subspaces],
-        outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
+        remove_concept_from_list,
+        [remove_dropdown, selected_subspaces],
+        [selected_subspaces, active_subspaces_table, remove_dropdown]
+    )
 
-demo.queue(max_size=20).launch()
+demo.launch()
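
On the Gradio wiring in the last hunk: gr.ChatInterface takes its extra inputs at construction time and has no public method for attaching them afterwards, which is why the slider is passed through additional_inputs. A minimal standalone sketch of that documented pattern (the echo function and its values are placeholders):

import gradio as gr

def echo(message, history, max_new_tokens):
    # placeholder; a real handler would generate up to max_new_tokens tokens
    return f"(max_new_tokens={max_new_tokens}) {message}"

with gr.Blocks() as demo:
    slider = gr.Slider(1, 2048, value=512, label="Max new tokens", render=False)
    gr.ChatInterface(fn=echo, type="messages", additional_inputs=[slider])

if __name__ == "__main__":
    demo.launch()

Components created with render=False and passed this way are rendered by the interface itself, in the collapsible accordion beneath the chat box.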
 