Update app.py
Browse files
app.py
CHANGED
@@ -53,12 +53,13 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
|
|
53 |
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
|
54 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
55 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
56 |
-
|
|
|
57 |
|
58 |
|
59 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
60 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
61 |
-
num_beams=1
|
62 |
|
63 |
interpreted_vectors = global_state[:, i]
|
64 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
@@ -82,7 +83,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
82 |
# generate the interpretations
|
83 |
generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
84 |
generation_texts = tokenizer.batch_decode(generated)
|
85 |
-
|
|
|
86 |
|
87 |
|
88 |
## main
|
@@ -182,15 +184,18 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
|
|
182 |
for i in range(MAX_PROMPT_TOKENS):
|
183 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
184 |
tokens_container.append(btn)
|
185 |
-
|
186 |
-
|
187 |
-
|
|
|
|
|
188 |
for i, btn in enumerate(tokens_container):
|
189 |
-
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
|
|
190 |
top_k, top_p, repetition_penalty, length_penalty
|
191 |
], [*interpretation_bubbles])
|
192 |
|
193 |
original_prompt_btn.click(get_hidden_states,
|
194 |
[original_prompt_raw],
|
195 |
-
[global_state, *tokens_container])
|
196 |
demo.launch()
|
|
|
53 |
hidden_states = torch.stack([h.squeeze(0).cpu().detach() for h in outputs.hidden_states], dim=0)
|
54 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
55 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
56 |
+
progress_dummy_output = ''
|
57 |
+
return [progress_dummy_output, hidden_states, *token_btns]
|
58 |
|
59 |
|
60 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
61 |
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
62 |
+
num_beams=1):
|
63 |
|
64 |
interpreted_vectors = global_state[:, i]
|
65 |
length_penalty = -length_penalty # unintuitively, length_penalty > 0 will make sequences longer, so we negate it
|
|
|
83 |
# generate the interpretations
|
84 |
generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
85 |
generation_texts = tokenizer.batch_decode(generated)
|
86 |
+
progress_dummy_output = ''
|
87 |
+
return [progress_dummy_output] + [gr.Textbox(text, visible=True, container=False) for text in generation_texts]
|
88 |
|
89 |
|
90 |
## main
|
|
|
184 |
for i in range(MAX_PROMPT_TOKENS):
|
185 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
186 |
tokens_container.append(btn)
|
187 |
+
|
188 |
+
progress_dummy = gr.Text('', container=False)
|
189 |
+
interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble'])
|
190 |
+
for i in range(model.config.num_hidden_layers)]
|
191 |
+
|
192 |
for i, btn in enumerate(tokens_container):
|
193 |
+
btn.click(partial(run_interpretation, i=i), [progress_dummy, global_state, interpretation_prompt,
|
194 |
+
num_tokens, do_sample, temperature,
|
195 |
top_k, top_p, repetition_penalty, length_penalty
|
196 |
], [*interpretation_bubbles])
|
197 |
|
198 |
original_prompt_btn.click(get_hidden_states,
|
199 |
[original_prompt_raw],
|
200 |
+
[progress_dummy, global_state, *tokens_container])
|
201 |
demo.launch()
|