frankaging commited on
Commit
bddba98
·
1 Parent(s): e3ab52c
Files changed (1) hide show
  1. app.py +39 -22
app.py CHANGED
@@ -13,7 +13,7 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
13
  login(token=HF_TOKEN)
14
 
15
  MAX_MAX_NEW_TOKENS = 2048
16
- DEFAULT_MAX_NEW_TOKENS = 512 # smaller default to save memory
17
  MAX_INPUT_TOKEN_LENGTH = 4096
18
 
19
  def load_jsonl(jsonl_path):
@@ -29,7 +29,8 @@ class Steer(pv.SourcelessIntervention):
29
  def __init__(self, **kwargs):
30
  super().__init__(**kwargs, keep_last_dim=True)
31
  self.proj = torch.nn.Linear(
32
- self.embed_dim, kwargs["latent_dim"], bias=False)
 
33
  def forward(self, base, source=None, subspaces=None):
34
  steering_vec = torch.tensor(subspaces["mag"]) * \
35
  self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
@@ -82,9 +83,9 @@ def generate(
82
 
83
  # build list of messages
84
  messages = []
85
- for user_msg, model_msg in recent_history:
86
- messages.append({"role": "user", "content": user_msg})
87
- messages.append({"role": "model", "content": model_msg})
88
  messages.append({"role": "user", "content": message})
89
 
90
  input_ids = torch.tensor([tokenizer.apply_chat_template(
@@ -101,7 +102,12 @@ def generate(
101
  "unit_locations": None,
102
  "max_new_tokens": max_new_tokens,
103
  "intervene_on_prompt": True,
104
- "subspaces": [{"idx": int(subspaces_list[0]["idx"]), "mag": int(subspaces_list[0]["internal_mag"])}],
 
 
 
 
 
105
  "streamer": streamer,
106
  "do_sample": True
107
  }
@@ -121,8 +127,14 @@ def filter_concepts(search_text: str):
121
  return filtered[:500]
122
 
123
  def add_concept_to_list(selected_concept, user_slider_val, current_list):
 
 
 
 
 
124
  if not selected_concept:
125
- return current_list, _build_table_data(current_list), gr.update(choices=_build_remove_choices(current_list))
 
126
  idx = concept_id_map[selected_concept]
127
  internal_mag = user_slider_val * 50
128
  new_entry = {
@@ -132,24 +144,18 @@ def add_concept_to_list(selected_concept, user_slider_val, current_list):
132
  "internal_mag": internal_mag,
133
  }
134
  updated_list = current_list + [new_entry]
135
- return (
136
- updated_list,
137
- _build_table_data(updated_list),
138
- gr.update(choices=_build_remove_choices(updated_list))
139
- )
140
 
141
  def remove_concept_from_list(selected_text, current_list):
 
 
 
 
 
142
  if not selected_text:
143
- return current_list, _build_table_data(current_list), gr.update(choices=_build_remove_choices(current_list))
144
  updated_list = [x for x in current_list if x["text"] != selected_text]
145
- return (
146
- updated_list,
147
- _build_table_data(updated_list),
148
- gr.update(choices=_build_remove_choices(updated_list))
149
- )
150
-
151
- def _build_table_data(subspaces):
152
- return [[x["text"], x["display_mag"]] for x in subspaces]
153
 
154
  def _build_remove_choices(subspaces):
155
  return [x["text"] for x in subspaces]
@@ -211,12 +217,23 @@ with gr.Blocks(css="style.css") as demo:
211
  remove_button = gr.Button("Remove", variant="secondary")
212
 
213
  # Wire up events
214
- search_box.change(update_dropdown_choices, [search_box], [concept_dropdown])
 
 
 
 
 
 
 
 
215
  add_button.click(
216
  add_concept_to_list,
217
  [concept_dropdown, concept_magnitude, selected_subspaces],
218
  [selected_subspaces, remove_dropdown]
219
  )
 
 
 
220
  remove_button.click(
221
  remove_concept_from_list,
222
  [remove_dropdown, selected_subspaces],
 
13
  login(token=HF_TOKEN)
14
 
15
  MAX_MAX_NEW_TOKENS = 2048
16
+ DEFAULT_MAX_NEW_TOKENS = 128 # smaller default to save memory
17
  MAX_INPUT_TOKEN_LENGTH = 4096
18
 
19
  def load_jsonl(jsonl_path):
 
29
  def __init__(self, **kwargs):
30
  super().__init__(**kwargs, keep_last_dim=True)
31
  self.proj = torch.nn.Linear(
32
+ self.embed_dim, kwargs["latent_dim"], bias=False
33
+ )
34
  def forward(self, base, source=None, subspaces=None):
35
  steering_vec = torch.tensor(subspaces["mag"]) * \
36
  self.proj.weight[subspaces["idx"]].unsqueeze(dim=0)
 
83
 
84
  # build list of messages
85
  messages = []
86
+ # for user_msg, model_msg in recent_history:
87
+ # messages.append({"role": "user", "content": user_msg})
88
+ # messages.append({"role": "model", "content": model_msg})
89
  messages.append({"role": "user", "content": message})
90
 
91
  input_ids = torch.tensor([tokenizer.apply_chat_template(
 
102
  "unit_locations": None,
103
  "max_new_tokens": max_new_tokens,
104
  "intervene_on_prompt": True,
105
+ "subspaces": [
106
+ {
107
+ "idx": int(subspaces_list[0]["idx"]),
108
+ "mag": int(subspaces_list[0]["internal_mag"])
109
+ }
110
+ ] if subspaces_list else [],
111
  "streamer": streamer,
112
  "do_sample": True
113
  }
 
127
  return filtered[:500]
128
 
129
  def add_concept_to_list(selected_concept, user_slider_val, current_list):
130
+ """
131
+ Return exactly 2 values:
132
+ 1) The updated list of concepts (list of dicts).
133
+ 2) A Gradio update for the removal dropdown’s choices.
134
+ """
135
  if not selected_concept:
136
+ return current_list, gr.update(choices=_build_remove_choices(current_list))
137
+
138
  idx = concept_id_map[selected_concept]
139
  internal_mag = user_slider_val * 50
140
  new_entry = {
 
144
  "internal_mag": internal_mag,
145
  }
146
  updated_list = current_list + [new_entry]
147
+ return updated_list, gr.update(choices=_build_remove_choices(updated_list))
 
 
 
 
148
 
149
  def remove_concept_from_list(selected_text, current_list):
150
+ """
151
+ Return exactly 2 values:
152
+ 1) The updated list of concepts (list of dicts).
153
+ 2) A Gradio update for the removal dropdown’s choices.
154
+ """
155
  if not selected_text:
156
+ return current_list, gr.update(choices=_build_remove_choices(current_list))
157
  updated_list = [x for x in current_list if x["text"] != selected_text]
158
+ return updated_list, gr.update(choices=_build_remove_choices(updated_list))
 
 
 
 
 
 
 
159
 
160
  def _build_remove_choices(subspaces):
161
  return [x["text"] for x in subspaces]
 
217
  remove_button = gr.Button("Remove", variant="secondary")
218
 
219
  # Wire up events
220
+ # When the search box changes, update the concept dropdown choices:
221
+ search_box.change(
222
+ update_dropdown_choices,
223
+ [search_box],
224
+ [concept_dropdown]
225
+ )
226
+
227
+ # When "Add Concept" is clicked, add the concept + magnitude to the list,
228
+ # and update the "Remove" dropdown choices.
229
  add_button.click(
230
  add_concept_to_list,
231
  [concept_dropdown, concept_magnitude, selected_subspaces],
232
  [selected_subspaces, remove_dropdown]
233
  )
234
+
235
+ # When "Remove" is clicked, remove the selected concept from the list,
236
+ # and update the "Remove" dropdown choices.
237
  remove_button.click(
238
  remove_concept_from_list,
239
  [remove_dropdown, selected_subspaces],