Update app.py
Browse files
app.py
CHANGED
@@ -87,7 +87,7 @@ def get_hidden_states(local_state, raw_original_prompt):
|
|
87 |
progress_dummy_output = ''
|
88 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
89 |
local_state.hidden_states = hidden_states.cpu().detach()
|
90 |
-
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
91 |
|
92 |
|
93 |
@spaces.GPU
|
@@ -217,18 +217,19 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
217 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
218 |
) for i in range(MAX_NUM_LAYERS)]
|
219 |
|
|
|
|
|
220 |
|
221 |
-
# event listeners
|
222 |
-
|
223 |
for i, btn in enumerate(tokens_container):
|
224 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
225 |
num_tokens, do_sample, temperature,
|
226 |
top_k, top_p, repetition_penalty, length_penalty
|
227 |
], [progress_dummy, *interpretation_bubbles])
|
228 |
-
|
229 |
original_prompt_btn.click(get_hidden_states,
|
230 |
-
[
|
231 |
-
[progress_dummy, *tokens_container, *interpretation_bubbles])
|
232 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
233 |
|
234 |
extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
|
|
|
87 |
progress_dummy_output = ''
|
88 |
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
89 |
local_state.hidden_states = hidden_states.cpu().detach()
|
90 |
+
return [progress_dummy_output, local_state, *token_btns, *invisible_bubbles]
|
91 |
|
92 |
|
93 |
@spaces.GPU
|
|
|
217 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
218 |
) for i in range(MAX_NUM_LAYERS)]
|
219 |
|
220 |
+
|
221 |
+
local_state = gr.State(global_state.local_state)
|
222 |
|
223 |
+
# event listeners
|
|
|
224 |
for i, btn in enumerate(tokens_container):
|
225 |
btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
|
226 |
num_tokens, do_sample, temperature,
|
227 |
top_k, top_p, repetition_penalty, length_penalty
|
228 |
], [progress_dummy, *interpretation_bubbles])
|
229 |
+
|
230 |
original_prompt_btn.click(get_hidden_states,
|
231 |
+
[local_state, original_prompt_raw],
|
232 |
+
[progress_dummy, local_state, *tokens_container, *interpretation_bubbles])
|
233 |
original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
|
234 |
|
235 |
extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
|