dar-tau commited on
Commit
9136f03
1 Parent(s): 8f43d2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -87,7 +87,7 @@ def get_hidden_states(local_state, raw_original_prompt):
87
  progress_dummy_output = ''
88
  invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
89
  local_state.hidden_states = hidden_states.cpu().detach()
90
- return [progress_dummy_output, *token_btns, *invisible_bubbles]
91
 
92
 
93
  @spaces.GPU
@@ -217,18 +217,19 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
217
  elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
218
  ) for i in range(MAX_NUM_LAYERS)]
219
 
 
 
220
 
221
- # event listeners
222
-
223
  for i, btn in enumerate(tokens_container):
224
  btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
225
  num_tokens, do_sample, temperature,
226
  top_k, top_p, repetition_penalty, length_penalty
227
  ], [progress_dummy, *interpretation_bubbles])
228
-
229
  original_prompt_btn.click(get_hidden_states,
230
- [gr.State(global_state.local_state), original_prompt_raw],
231
- [progress_dummy, *tokens_container, *interpretation_bubbles])
232
  original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
233
 
234
  extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
 
87
  progress_dummy_output = ''
88
  invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
89
  local_state.hidden_states = hidden_states.cpu().detach()
90
+ return [progress_dummy_output, local_state, *token_btns, *invisible_bubbles]
91
 
92
 
93
  @spaces.GPU
 
217
  elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
218
  ) for i in range(MAX_NUM_LAYERS)]
219
 
220
+
221
+ local_state = gr.State(global_state.local_state)
222
 
223
+ # event listeners
 
224
  for i, btn in enumerate(tokens_container):
225
  btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
226
  num_tokens, do_sample, temperature,
227
  top_k, top_p, repetition_penalty, length_penalty
228
  ], [progress_dummy, *interpretation_bubbles])
229
+
230
  original_prompt_btn.click(get_hidden_states,
231
+ [local_state, original_prompt_raw],
232
+ [progress_dummy, local_state, *tokens_container, *interpretation_bubbles])
233
  original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
234
 
235
  extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]