dar-tau committed
Commit c23388b
1 Parent(s): 9136f03

Update app.py

Files changed (1)
  1. app.py +6 -10
app.py CHANGED
@@ -74,8 +74,7 @@ def reset_model(model_name, *extra_components, with_extra_components=True):
                       + [*extra_components])
 
 
-@spaces.GPU
-def get_hidden_states(local_state, raw_original_prompt):
+def get_hidden_states(raw_original_prompt):
     model, tokenizer = global_state.model, global_state.tokenizer
     original_prompt = global_state.original_prompt_template.format(prompt=raw_original_prompt)
     model_inputs = tokenizer(original_prompt, add_special_tokens=False, return_tensors="pt").to(model.device)
@@ -86,8 +85,8 @@ def get_hidden_states(local_state, raw_original_prompt):
                   + [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
     progress_dummy_output = ''
     invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
-    local_state.hidden_states = hidden_states.cpu().detach()
-    return [progress_dummy_output, local_state, *token_btns, *invisible_bubbles]
+    global_state.local_state.hidden_states = hidden_states.cpu().detach()
+    return [progress_dummy_output, *token_btns, *invisible_bubbles]
 
 
 @spaces.GPU
@@ -216,10 +215,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
     interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
                                          elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
                                          ) for i in range(MAX_NUM_LAYERS)]
-
-
-    local_state = gr.State(global_state.local_state)
-
+
     # event listeners
     for i, btn in enumerate(tokens_container):
         btn.click(partial(run_interpretation, i=i), [interpretation_prompt,
@@ -228,8 +224,8 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
                                                      ], [progress_dummy, *interpretation_bubbles])
 
     original_prompt_btn.click(get_hidden_states,
-                              [local_state, original_prompt_raw],
-                              [progress_dummy, local_state, *tokens_container, *interpretation_bubbles])
+                              [original_prompt_raw],
+                              [progress_dummy, *tokens_container, *interpretation_bubbles])
     original_prompt_raw.change(lambda: [gr.Button(visible=False) for _ in range(MAX_PROMPT_TOKENS)], [], tokens_container)
 
     extra_components = [interpretation_prompt, original_prompt_raw, original_prompt_btn]
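
In short, the commit stops threading a per-session gr.State through get_hidden_states and its click listener: the function now takes only the raw prompt, stores the detached hidden states on the module-level global_state.local_state, and the @spaces.GPU decorator and the local_state gr.State component are removed. Below is a minimal, self-contained sketch of that pattern under stated assumptions; the names (global_state, compute, demo) are illustrative stand-ins, not the actual app.py code.

import types
import gradio as gr

# Module-level state standing in for app.py's global_state.local_state
global_state = types.SimpleNamespace(
    local_state=types.SimpleNamespace(hidden_states=None))

def compute(prompt):
    # Store the heavy result on shared module-level state instead of
    # returning it through a gr.State output.
    global_state.local_state.hidden_states = [len(tok) for tok in prompt.split()]
    return f"stored {len(global_state.local_state.hidden_states)} token lengths"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    status = gr.Textbox(label="status")
    btn = gr.Button("run")
    btn.click(compute, [prompt], [status])  # no gr.State in inputs/outputs

if __name__ == "__main__":
    demo.launch()

One trade-off worth noting: unlike gr.State, which is kept per browser session, module-level state is shared by all concurrent sessions of the Space, so one user's hidden states can overwrite another's unless the app is effectively single-session.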