dar-tau committed
Commit de099ae
1 Parent(s): af61663

Update app.py

Files changed (1):
  1. app.py +13 -8
app.py CHANGED
@@ -53,12 +53,13 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
     hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
     token_btns = ([gr.Button(token, visible=True) for token in tokens]
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
-    return [hidden_states, *token_btns]
+    progress_dummy_output = ''
+    return [progress_dummy_output, hidden_states, *token_btns]
 
 
 def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, i,
-                       num_beams=1, progress=gr.Progress()):
+                       num_beams=1):
 
     interpreted_vectors = global_state[:, i]
     length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
@@ -82,7 +83,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     # generate the interpretations
     generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
-    return [gr.Textbox(text, visible=True, container=False) for text in generation_texts]
+    progress_dummy_output = ''
+    return [progress_dummy_output] + [gr.Textbox(text, visible=True, container=False) for text in generation_texts]
 
 
 ## main
@@ -182,15 +184,18 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
     for i in range(MAX_PROMPT_TOKENS):
         btn = gr.Button('', visible=False, elem_classes=['token_btn'])
         tokens_container.append(btn)
-    interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
-                              for i in range(model.config.num_hidden_layers)]
-
+
+    progress_dummy = gr.Text('', container=False)
+    interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
+                              for i in range(model.config.num_hidden_layers)]
+
     for i, btn in enumerate(tokens_container):
-        btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt, num_tokens, do_sample, temperature,
+        btn.click(partial(run_interpretation, i=i), [progress_dummy, global_state, interpretation_prompt,
+                                                     num_tokens, do_sample, temperature,
                                                      top_k, top_p, repetition_penalty, length_penalty
                                                      ], [*interpretation_bubbles])
 
     original_prompt_btn.click(get_hidden_states,
                               [original_prompt_raw],
-                              [global_state, *tokens_container])
+                              [progress_dummy, global_state, *tokens_container])
     demo.launch()
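
The recurring change above is a "dummy progress output" pattern: Gradio draws the bar for a progress=gr.Progress() tracker over the event's output components, so each handler gains a placeholder first output (progress_dummy / progress_dummy_output = '') that gives the bar a stable place to render. Below is a minimal standalone sketch of that pattern, not code from this commit: slow_fn and the component names are illustrative, and the rendering behavior is an assumption based on how gr.Progress displays on outputs.

# Minimal sketch of the dummy-progress-output pattern (hypothetical example;
# assumes gr.Progress renders its bar over the event's output components).
import time

import gradio as gr


def slow_fn(text, progress=gr.Progress()):
    # progress.tqdm wraps an iterable and reports each step to the UI
    for _ in progress.tqdm(range(10)):
        time.sleep(0.2)
    # one return value per output component: the first fills the
    # placeholder (mirroring progress_dummy_output = '' in the diff),
    # the second is the real result
    return '', text.upper()


with gr.Blocks() as demo:
    inp = gr.Textbox(label='input')
    progress_dummy = gr.Text('', container=False)  # placeholder the bar renders over
    out = gr.Textbox(label='output')
    gr.Button('Run').click(slow_fn, [inp], [progress_dummy, out])

demo.launch()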