frankaging committed on
Commit
0fb9f4b
·
1 Parent(s): f9cd90a
Files changed (1) hide show
  1. app.py +57 -42
app.py CHANGED
@@ -14,7 +14,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
14
  login(token=HF_TOKEN)
15
 
16
  MAX_MAX_NEW_TOKENS = 2048
17
- DEFAULT_MAX_NEW_TOKENS = 256 # smaller default to save memory
18
  MAX_INPUT_TOKEN_LENGTH = 4096
19
 
20
  css = """
@@ -27,6 +27,13 @@ css = """
27
  border-radius: 4px;
28
  font-weight: 500;
29
  }
 
 
 
 
 
 
 
30
  """
31
 
32
  def load_jsonl(jsonl_path):
@@ -212,7 +219,6 @@ def generate(
212
  }
213
  ] if steering_list else None, # if steering is not provided, we do not steer.
214
  "streamer": streamer,
215
- "repetition_penalty": 1.5,
216
  "do_sample": True
217
  }
218
 
@@ -252,87 +258,96 @@ def add_concept_to_list(selected_concept, user_slider_val, current_list):
252
  current_list = [new_entry]
253
  return current_list
254
 
255
- def update_dropdown_choices(search_text):
256
  filtered = filter_concepts(search_text)
257
  if not filtered or len(filtered) == 0:
258
- return gr.update(choices=[f"[New] {search_text}"], value=f"[New] {search_text}", interactive=True), gr.Textbox(
259
- label="No matching existing concepts were found!",
260
- value="Good news! Based on the concept you provided, we will automatically generate a steering vector. Try it out by starting a chat!",
261
- lines=3,
262
- interactive=False,
263
- visible=True,
264
- elem_id="alert-message"
265
- )
266
- # Automatically select the first matching concept
 
 
 
 
 
 
 
 
 
 
267
  return gr.update(
268
  choices=filtered,
269
- value=filtered[0], # Select the first match
270
- interactive=True, visible=True
 
271
  ), gr.Textbox(visible=False)
272
 
273
  with gr.Blocks(css=css, fill_height=True) as demo:
274
- # States for both detection and steering
275
  selected_detection = gr.State([])
276
  selected_subspaces = gr.State([])
277
 
278
- with gr.Row(min_height=1000):
279
  # Left side: chat area
280
  with gr.Column(scale=7):
281
  chat_interface = gr.ChatInterface(
282
  fn=generate,
283
- title="Chat with a Concept Steering Model",
284
- description="""You can only steer the model when a concept is detected internally. Select concepts on the right →\n\nWe intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20.""",
285
- type="messages",
 
 
 
 
286
  additional_inputs=[selected_detection, selected_subspaces],
287
  fill_height=True,
288
- css=".gradio-chatbot {min-height: 1500px;}"
289
  )
290
 
291
  # Right side: concept detection and steering
292
  with gr.Column(scale=3):
293
- # Concept Detection Panel
294
- # gr.Markdown("## Detect then Steer")
295
- gr.Markdown("Select a concept to detect. We will only steer the model when this concept is detected internally.")
296
- with gr.Group():
297
  detect_search = gr.Textbox(
298
- label="Search Detection Concepts",
299
- placeholder="Find concepts to detect (e.g. 'Google')",
300
  lines=1,
301
  )
302
  detect_msg = gr.TextArea(visible=False)
303
  detect_dropdown = gr.Dropdown(
304
- label="Select concept to detect",
305
  interactive=True,
306
  allow_custom_value=False,
307
  )
308
  detect_threshold = gr.Slider(
309
- label="Detection Threshold",
310
  minimum=0,
311
  maximum=1,
312
- step=0.01,
313
  value=0.5,
314
  )
315
 
316
- # Divider
317
- # gr.Markdown("---")
318
 
319
- # Steering Panel (existing)
320
- # gr.Markdown("## Steer Response")
321
- gr.Markdown("Select a concept to steer when detection occurs.")
322
- with gr.Group():
323
  search_box = gr.Textbox(
324
- label="Search Steering Concepts",
325
- placeholder="Find concepts to steer the model (e.g. 'ethics and morality')",
326
  lines=1,
327
  )
328
  msg = gr.TextArea(visible=False)
329
  concept_dropdown = gr.Dropdown(
330
- label="Select concept to steer",
331
  interactive=True,
332
  allow_custom_value=False,
333
  )
334
  concept_magnitude = gr.Slider(
335
- label="Steering Intensity",
336
  minimum=-5,
337
  maximum=5,
338
  step=0.1,
@@ -341,7 +356,7 @@ with gr.Blocks(css=css, fill_height=True) as demo:
341
 
342
  # Wire up events for detection
343
  detect_search.input(
344
- update_dropdown_choices,
345
  [detect_search],
346
  [detect_dropdown, detect_msg]
347
  ).then(
@@ -362,9 +377,9 @@ with gr.Blocks(css=css, fill_height=True) as demo:
362
  [selected_detection]
363
  )
364
 
365
- # Wire up events for steering (existing)
366
  search_box.input(
367
- update_dropdown_choices,
368
  [search_box],
369
  [concept_dropdown, msg]
370
  ).then(
 
14
  login(token=HF_TOKEN)
15
 
16
  MAX_MAX_NEW_TOKENS = 2048
17
+ DEFAULT_MAX_NEW_TOKENS = 128 # smaller default to save memory
18
  MAX_INPUT_TOKEN_LENGTH = 4096
19
 
20
  css = """
 
27
  border-radius: 4px;
28
  font-weight: 500;
29
  }
30
+
31
+ .concept-help {
32
+ font-size: 0.9em;
33
+ color: #666;
34
+ margin-top: 4px;
35
+ font-style: italic;
36
+ }
37
  """
38
 
39
  def load_jsonl(jsonl_path):
 
219
  }
220
  ] if steering_list else None, # if steering is not provided, we do not steer.
221
  "streamer": streamer,
 
222
  "do_sample": True
223
  }
224
 
 
258
  current_list = [new_entry]
259
  return current_list
260
 
261
+ def update_dropdown_choices(search_text, is_detection=False):
262
  filtered = filter_concepts(search_text)
263
  if not filtered or len(filtered) == 0:
264
+ alert_message = (
265
+ "Good news! Based on the topic you provided, we will automatically generate a detector for you!"
266
+ ) if is_detection else (
267
+ "Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!"
268
+ )
269
+
270
+ return gr.update(
271
+ choices=[],
272
+ value=None,
273
+ interactive=True
274
+ ), gr.Textbox(
275
+ label="No matching topics found",
276
+ value=alert_message,
277
+ lines=3,
278
+ interactive=False,
279
+ visible=True,
280
+ elem_id="alert-message"
281
+ )
282
+
283
  return gr.update(
284
  choices=filtered,
285
+ value=filtered[0],
286
+ interactive=True,
287
+ visible=True
288
  ), gr.Textbox(visible=False)
289
 
290
  with gr.Blocks(css=css, fill_height=True) as demo:
 
291
  selected_detection = gr.State([])
292
  selected_subspaces = gr.State([])
293
 
294
+ with gr.Row(min_height=500, equal_height=True):
295
  # Left side: chat area
296
  with gr.Column(scale=7):
297
  chat_interface = gr.ChatInterface(
298
  fn=generate,
299
+ title="Conditionally Steer AI Responses Based on Topics",
300
+ description="""This is an experimental chatbot that you can steer using topics you care about:
301
+
302
+ Step 1: Choose a topic to detect (e.g., "Google")
303
+ Step 2: Choose a topic you want the model to discuss when the previous topic comes up (e.g., "ethics")
304
+
305
+ Try it out! For example, set it to detect "Google" topics and steer toward discussing "ethics". We intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20.""",
306
  additional_inputs=[selected_detection, selected_subspaces],
307
  fill_height=True,
 
308
  )
309
 
310
  # Right side: concept detection and steering
311
  with gr.Column(scale=3):
312
+ gr.Markdown("""#### Step 1: Choose a topic you want to recognize.""")
313
+ with gr.Group():
 
 
314
  detect_search = gr.Textbox(
315
+ label="Search for topics to detect",
316
+ placeholder="Try: 'Google'",
317
  lines=1,
318
  )
319
  detect_msg = gr.TextArea(visible=False)
320
  detect_dropdown = gr.Dropdown(
321
+ label="Choose a topic to detect (Click to see more!)",
322
  interactive=True,
323
  allow_custom_value=False,
324
  )
325
  detect_threshold = gr.Slider(
326
+ label="Detection sensitivity",
327
  minimum=0,
328
  maximum=1,
329
+ step=0.1,
330
  value=0.5,
331
  )
332
 
333
+ gr.Markdown("---")
 
334
 
335
+ gr.Markdown("""#### Step 2: Choose another topic you want to discuss when it detects the chosen topic above.""")
336
+
337
+ with gr.Group():
 
338
  search_box = gr.Textbox(
339
+ label="Search topics to steer",
340
+ placeholder="Try: 'ethics'",
341
  lines=1,
342
  )
343
  msg = gr.TextArea(visible=False)
344
  concept_dropdown = gr.Dropdown(
345
+ label="Choose a topic to steer the model (Click to see more!)",
346
  interactive=True,
347
  allow_custom_value=False,
348
  )
349
  concept_magnitude = gr.Slider(
350
+ label="Steering intensity",
351
  minimum=-5,
352
  maximum=5,
353
  step=0.1,
 
356
 
357
  # Wire up events for detection
358
  detect_search.input(
359
+ lambda x: update_dropdown_choices(x, is_detection=True),
360
  [detect_search],
361
  [detect_dropdown, detect_msg]
362
  ).then(
 
377
  [selected_detection]
378
  )
379
 
380
+ # Wire up events for steering
381
  search_box.input(
382
+ lambda x: update_dropdown_choices(x, is_detection=False),
383
  [search_box],
384
  [concept_dropdown, msg]
385
  ).then(