frankaging committed on
Commit
7497e24
·
1 Parent(s): f860e61
Files changed (1) hide show
  1. app.py +114 -77
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, json
2
  import torch
3
  import gradio as gr
4
  import spaces
@@ -44,13 +44,14 @@ class Steer(pv.SourcelessIntervention):
44
  self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)
45
 
46
def forward(self, base, source=None, subspaces=None):
    """Add scaled concept directions to ``base``.

    Args:
        base: Value to steer; returned unchanged (same object) when
            ``subspaces`` is ``None`` or empty.
            # presumably a hidden-state tensor whose last dim equals
            # the projection's input size -- confirm against caller
        source: Accepted for interface compatibility; not read here.
        subspaces: Optional iterable of dicts, each carrying ``"idx"``
            (row of ``self.proj.weight`` to use) and ``"mag"`` (the
            multiplier applied to that row).

    Returns:
        ``base`` plus the sum of ``mag * self.proj.weight[idx]`` over all
        entries, each row given a leading dimension of size 1 before the
        addition.
    """
    if subspaces is None:
        return base

    steered = base
    for entry in subspaces:
        row_index, scale = entry["idx"], entry["mag"]
        # unsqueeze adds a leading axis so the row can broadcast over base.
        steered = steered + scale * self.proj.weight[row_index].unsqueeze(dim=0)
    return steered
@@ -107,39 +108,30 @@ def generate(
107
  recent_history = chat_history[start_idx:]
108
 
109
  # Build a list of messages
110
- # each tuple is (user_message, assistant_message)
111
  messages = []
112
- for user_msg, assistant_msg in recent_history:
113
  messages.append({"role": "user", "content": user_msg})
114
- messages.append({"role": "assistant", "content": assistant_msg})
115
 
116
  # Now append the new user message
117
  messages.append({"role": "user", "content": message})
118
 
119
- # Convert messages into model input tokens with a generation prompt
120
- prompt = tokenizer.apply_chat_template(
121
- messages,
122
- tokenize=True,
123
- add_generation_prompt=True # appends a final "Assistant:" for the model to continue
124
- )
125
-
126
- # Retrieve input_ids and mask
127
- input_ids = torch.tensor([prompt["input_ids"]]).cuda()
128
- attention_mask = torch.tensor([prompt["attention_mask"]]).cuda()
129
 
130
  # Possibly trim if over max length
131
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
132
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
133
- attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]
134
  yield "\n[Warning: Truncated conversation exceeds max allowed input tokens]\n"
135
 
136
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
137
  generate_kwargs = {
138
- "base": {"input_ids": input_ids, "attention_mask": attention_mask},
139
  "unit_locations": None,
140
  "max_new_tokens": max_new_tokens,
141
  "intervene_on_prompt": True,
142
- "subspaces": subspaces_list,
143
  "streamer": streamer,
144
  "eos_token_id": terminators,
145
  "early_stopping": True,
@@ -159,32 +151,62 @@ def generate(
159
  # UI Callbacks
160
  # --------------
161
def filter_concepts(search_text: str):
    """Return at most 500 concepts whose text contains ``search_text``.

    Matching is a case-insensitive substring test against every entry of
    the module-level ``concept_list``. A missing (``None``), empty, or
    whitespace-only query returns the first 500 concepts unfiltered.

    Args:
        search_text: Raw text typed into the search box.

    Returns:
        A list of matching concept strings, in ``concept_list`` order,
        capped at 500 entries.
    """
    from itertools import islice  # local import: keeps the block self-contained

    limit = 500
    if not search_text or not search_text.strip():
        return concept_list[:limit]

    # Lowercase the query once instead of once per concept.
    needle = search_text.lower()
    matches = (concept for concept in concept_list if needle in concept.lower())
    # islice stops scanning as soon as the cap is reached.
    return list(islice(matches, limit))
166
 
167
def add_concept_to_list(selected_concept, magnitude, current_list):
    """Append the chosen concept and magnitude to the active subspaces.

    Args:
        selected_concept: Concept text picked in the dropdown; falsy when
            nothing is selected, in which case the list is left as is.
        magnitude: Steering magnitude stored under the ``"mag"`` key.
        current_list: Existing list of ``{"idx": ..., "mag": ...}`` dicts.
            It is never mutated; a new list is returned when an entry is
            added.

    Returns:
        A 3-tuple of (subspace list, table rows ``[[idx, mag], ...]``,
        ``gr.update`` carrying the choices for the removal dropdown).

    Raises:
        KeyError: If ``selected_concept`` is not a key of
            ``concept_id_map``.
    """
    if selected_concept:
        new_entry = {"idx": concept_id_map[selected_concept], "mag": magnitude}
        updated_list = current_list + [new_entry]
    else:
        updated_list = current_list

    # Build table rows on every path: the no-selection path used to hand the
    # raw list of dicts to the table output instead of [idx, mag] rows.
    table_data = [[x["idx"], x["mag"]] for x in updated_list]
    # The same concept can be added more than once; list each index only once
    # in the removal dropdown while preserving insertion order.
    remove_choices = list(dict.fromkeys(str(x["idx"]) for x in updated_list))
    return updated_list, table_data, gr.update(choices=remove_choices)
 
 
 
 
 
 
 
 
178
 
179
def remove_concept_from_list(rem_concept_idx_str, current_list):
    """Remove every active entry whose concept index matches the selection.

    Args:
        rem_concept_idx_str: Concept index as the string chosen in the
            removal dropdown; falsy when nothing is selected.
        current_list: Existing list of ``{"idx": ..., "mag": ...}`` dicts.
            It is never mutated.

    Returns:
        A 3-tuple of (subspace list, table rows ``[[idx, mag], ...]``,
        ``gr.update`` for the removal dropdown). When nothing is selected
        the list is returned unchanged and the dropdown is left untouched.

    Raises:
        ValueError: If ``rem_concept_idx_str`` is not an integer string.
    """
    if not rem_concept_idx_str:
        # Return table rows, not the raw dicts, so the table output always
        # receives the same [idx, mag] shape as on the removal path.
        table_data = [[x["idx"], x["mag"]] for x in current_list]
        return current_list, table_data, gr.update()

    rem_idx = int(rem_concept_idx_str)
    updated_list = [x for x in current_list if x["idx"] != rem_idx]
    table_data = [[x["idx"], x["mag"]] for x in updated_list]
    # De-duplicate while preserving order; the same index may appear twice.
    remove_choices = list(dict.fromkeys(str(x["idx"]) for x in updated_list))
    return updated_list, table_data, gr.update(choices=remove_choices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  def update_dropdown_choices(search_text):
190
  filtered = filter_concepts(search_text)
@@ -198,81 +220,96 @@ with gr.Blocks(css="style.css") as demo:
198
  gr.Markdown(DESCRIPTION)
199
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
200
 
201
- selected_subspaces = gr.State([])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  with gr.Row():
204
- with gr.Column():
205
- # Searching / selecting a concept
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  search_box = gr.Textbox(
207
  label="Search concepts",
208
  placeholder="Type text to filter concepts (e.g. 'sports')"
209
  )
210
  concept_dropdown = gr.Dropdown(
211
  label="Filtered Concepts",
212
- choices=[],
213
  multiselect=False
214
  )
215
  concept_magnitude = gr.Slider(
216
- label="Magnitude",
217
- minimum=-300.0,
218
- maximum=300.0,
219
  step=1.0,
220
- value=150.0
221
  )
222
  add_button = gr.Button("Add Concept")
223
 
224
- # Removal
225
- remove_dropdown = gr.Dropdown(
226
- label="Remove from active list",
227
- choices=[],
228
- multiselect=False
229
- )
230
- remove_button = gr.Button("Remove Selected")
231
-
232
- with gr.Column():
233
- # Display currently active subspaces
234
  active_subspaces_table = gr.Dataframe(
235
- headers=["idx", "magnitude"],
236
- datatype=["number", "number"],
237
  interactive=False,
238
- label="Active Concept Subspaces"
 
 
239
  )
240
 
241
- # The Chat Interface
242
- chat_interface = gr.ChatInterface(
243
- fn=generate,
244
- additional_inputs=[
245
- gr.Slider(
246
- label="Max new tokens",
247
- minimum=1,
248
- maximum=MAX_MAX_NEW_TOKENS,
249
- step=1,
250
- value=DEFAULT_MAX_NEW_TOKENS,
251
- ),
252
- selected_subspaces
253
- ],
254
- title="Model Steering with ReFT-r1 (16K concepts)",
255
- )
256
 
257
  gr.Markdown(LICENSE)
258
 
259
  # Wire up events
 
260
  search_box.change(
261
  fn=update_dropdown_choices,
262
  inputs=[search_box],
263
  outputs=[concept_dropdown]
264
  )
265
 
 
266
  add_button.click(
267
  fn=add_concept_to_list,
268
  inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
269
- outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
270
  )
271
 
 
272
  remove_button.click(
273
- fn=remove_concept_from_list,
274
- inputs=[remove_dropdown, selected_subspaces],
275
- outputs=[selected_subspaces, active_subspaces_table, remove_dropdown],
276
  )
277
 
278
  demo.queue(max_size=20).launch()
 
1
+ import os, json, random
2
  import torch
3
  import gradio as gr
4
  import spaces
 
44
  self.proj = torch.nn.Linear(self.embed_dim, kwargs["latent_dim"], bias=False)
45
 
46
def forward(self, base, source=None, subspaces=None):
    """Add scaled concept directions to ``base``.

    Args:
        base: Value to steer; returned unchanged (same object) when
            ``subspaces`` is ``None`` or empty.
            # presumably a hidden-state tensor whose last dim equals
            # the projection's input size -- confirm against caller
        source: Accepted for interface compatibility; not read here.
        subspaces: Optional iterable of dicts. Each must carry ``"idx"``
            (row of ``self.proj.weight`` to use) and ``"internal_mag"``
            (the multiplier actually applied); other keys are ignored.

    Returns:
        ``base`` plus the sum of ``internal_mag * self.proj.weight[idx]``
        over all entries, each row given a leading dimension of size 1
        before the addition.
    """
    if subspaces is None:
        return base

    steered = base
    for entry in subspaces:
        # Steering uses the internal magnitude, not any display value.
        row_index, scale = entry["idx"], entry["internal_mag"]
        # unsqueeze adds a leading axis so the row can broadcast over base.
        steered = steered + scale * self.proj.weight[row_index].unsqueeze(dim=0)
    return steered
 
108
  recent_history = chat_history[start_idx:]
109
 
110
  # Build a list of messages
111
+ # each tuple is (user_message, model_message)
112
  messages = []
113
+ for user_msg, model_msg in recent_history:
114
  messages.append({"role": "user", "content": user_msg})
115
+ messages.append({"role": "model", "content": model_msg})
116
 
117
  # Now append the new user message
118
  messages.append({"role": "user", "content": message})
119
 
120
+ input_ids = torch.tensor([tokenizer.apply_chat_template(
121
+ messages, tokenize=True, add_generation_prompt=True)]).cuda()
 
 
 
 
 
 
 
 
122
 
123
  # Possibly trim if over max length
124
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
125
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
126
  yield "\n[Warning: Truncated conversation exceeds max allowed input tokens]\n"
127
 
128
  streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
129
  generate_kwargs = {
130
+ "base": {"input_ids": input_ids},
131
  "unit_locations": None,
132
  "max_new_tokens": max_new_tokens,
133
  "intervene_on_prompt": True,
134
+ "subspaces": subspaces_list, # pass entire structure, using "internal_mag"
135
  "streamer": streamer,
136
  "eos_token_id": terminators,
137
  "early_stopping": True,
 
151
  # UI Callbacks
152
  # --------------
153
def filter_concepts(search_text: str):
    """Return at most 500 concepts that match ``search_text``.

    Matching is a case-insensitive substring test against every entry of
    the module-level ``concept_list``. A missing (``None``), empty, or
    whitespace-only query returns the first 500 concepts unfiltered.

    Args:
        search_text: Raw text typed into the search box.

    Returns:
        A list of matching concept strings, in ``concept_list`` order,
        capped at 500 entries.
    """
    from itertools import islice  # local import: keeps the block self-contained

    limit = 500
    if not search_text or not search_text.strip():
        return concept_list[:limit]

    # Lowercase the query once instead of once per concept.
    needle = search_text.lower()
    matches = (concept for concept in concept_list if needle in concept.lower())
    # islice stops scanning as soon as the cap is reached.
    return list(islice(matches, limit))
159
 
160
def add_concept_to_list(selected_concept, user_slider_val, current_list):
    """Handle the 'Add Concept' click by appending one steering entry.

    Args:
        selected_concept: Concept text chosen in the dropdown; falsy when
            nothing is selected, in which case the list is returned as is.
        user_slider_val: Magnitude as shown to the user. It is stored
            verbatim for display and multiplied by 50 for the value the
            model actually uses.
        current_list: Existing subspace entries; never mutated.

    Returns:
        A 2-tuple of (subspace list, table rows built from that list).

    Raises:
        KeyError: If ``selected_concept`` is not a key of
            ``concept_id_map``.
    """
    # Nothing picked: echo the current state back to both outputs.
    if not selected_concept:
        return current_list, _build_table_data(current_list)

    # Each entry keeps the user-facing magnitude next to the internal one so
    # the table and the model can read different scales from the same record.
    # Duplicate concepts are allowed; no de-duplication is performed.
    entry = {
        "text": selected_concept,
        "idx": concept_id_map[selected_concept],
        "display_mag": user_slider_val,
        "internal_mag": user_slider_val * 50,
    }
    extended = current_list + [entry]
    return extended, _build_table_data(extended)
188
+
189
def remove_selected_row(selected_rows, current_list):
    """Drop the first selected row from the active subspaces.

    Args:
        selected_rows: Sequence of selected row indices (e.g. ``[1]``);
            only the first element is used. Falsy means no selection.
        current_list: Existing subspace entries; never mutated.

    Returns:
        A 2-tuple of (subspace list, table rows built from that list).
        The list is returned unchanged when nothing is selected or the
        index falls outside ``0 .. len(current_list) - 1``.
    """
    # NOTE(review): the click handler passes the table component itself as
    # the first input -- confirm it really delivers selected row indices
    # rather than the table contents.
    remaining = current_list
    if selected_rows:
        target = selected_rows[0]  # single selection: ignore any further items
        # Negative or too-large indices are ignored rather than wrapped.
        if 0 <= target < len(current_list):
            remaining = current_list[:target] + current_list[target + 1:]
    return remaining, _build_table_data(remaining)
204
+
205
+ def _build_table_data(subspaces):
206
+ """
207
+ Build a list of [concept_text, display_mag] to show in the table.
208
+ """
209
+ return [[x["text"], x["display_mag"]] for x in subspaces]
210
 
211
  def update_dropdown_choices(search_text):
212
  filtered = filter_concepts(search_text)
 
220
  gr.Markdown(DESCRIPTION)
221
  gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
222
 
223
+ # If GPU is available, define a random default concept:
224
+ default_subspaces = []
225
+ if torch.cuda.is_available() and len(concept_list) > 0:
226
+ default_index = random.randint(0, len(concept_list) - 1)
227
+ default_concept = concept_list[default_index]
228
+ default_concept_idx = concept_id_map[default_concept]
229
+ # default slider is 3 => 3*50=150 internally
230
+ default_subspaces = [{
231
+ "text": default_concept,
232
+ "idx": default_concept_idx,
233
+ "display_mag": 3, # what user sees
234
+ "internal_mag": 150.0, # actual scaling
235
+ }]
236
+
237
+ # Keep state of subspaces
238
+ selected_subspaces = gr.State(default_subspaces)
239
 
240
  with gr.Row():
241
+ # Left column: Chat
242
+ with gr.Column(scale=5):
243
+ chat_interface = gr.ChatInterface(
244
+ fn=generate,
245
+ additional_inputs=[
246
+ gr.Slider(
247
+ label="Max new tokens",
248
+ minimum=1,
249
+ maximum=MAX_MAX_NEW_TOKENS,
250
+ step=1,
251
+ value=DEFAULT_MAX_NEW_TOKENS,
252
+ ),
253
+ selected_subspaces # pass the entire subspaces list
254
+ ],
255
+ title="Model Steering with ReFT-r1 (16K concepts)",
256
+ )
257
+
258
+ # Right column: concept searching, adding, table display, removal
259
+ with gr.Column(scale=4):
260
+ gr.Markdown("## Steering Concepts")
261
  search_box = gr.Textbox(
262
  label="Search concepts",
263
  placeholder="Type text to filter concepts (e.g. 'sports')"
264
  )
265
  concept_dropdown = gr.Dropdown(
266
  label="Filtered Concepts",
267
+ choices=[], # dynamically populated
268
  multiselect=False
269
  )
270
  concept_magnitude = gr.Slider(
271
+ label="Scaled Magnitude (multiplies by 50 internally)",
272
+ minimum=-5,
273
+ maximum=5,
274
  step=1.0,
275
+ value=3
276
  )
277
  add_button = gr.Button("Add Concept")
278
 
279
+ # Current subspaces table
 
 
 
 
 
 
 
 
 
280
  active_subspaces_table = gr.Dataframe(
281
+ headers=["Concept", "Magnitude (scaled)"],
282
+ datatype=["str", "number"],
283
  interactive=False,
284
+ row_selectable="single",
285
+ label="Active Concept Subspaces",
286
+ value=_build_table_data(default_subspaces)
287
  )
288
 
289
+ remove_button = gr.Button("Remove Selected Row")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  gr.Markdown(LICENSE)
292
 
293
  # Wire up events
294
+ # Whenever user types in search_box, update concept_dropdown
295
  search_box.change(
296
  fn=update_dropdown_choices,
297
  inputs=[search_box],
298
  outputs=[concept_dropdown]
299
  )
300
 
301
+ # Add concept
302
  add_button.click(
303
  fn=add_concept_to_list,
304
  inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
305
+ outputs=[selected_subspaces, active_subspaces_table],
306
  )
307
 
308
+ # Remove selected row from table
309
  remove_button.click(
310
+ fn=remove_selected_row,
311
+ inputs=[active_subspaces_table, selected_subspaces],
312
+ outputs=[selected_subspaces, active_subspaces_table],
313
  )
314
 
315
  demo.queue(max_size=20).launch()