Commit: "Update app.py" — diff for file app.py (changed).
@@ -13,7 +13,7 @@ MAX_PROMPT_TOKENS = 30
|
|
13 |
|
14 |
## info
|
15 |
dataset_info = [{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
|
16 |
-
{'name': '
|
17 |
'filter': lambda x: x['label'] == 1},
|
18 |
]
|
19 |
|
@@ -66,7 +66,7 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
|
|
66 |
|
67 |
|
68 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
69 |
-
temperature, top_k, top_p, repetition_penalty, length_penalty, i,
|
70 |
num_beams=1):
|
71 |
|
72 |
interpreted_vectors = global_state[:, i]
|
@@ -89,7 +89,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
89 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
90 |
|
91 |
# generate the interpretations
|
92 |
-
|
|
|
93 |
generation_texts = tokenizer.batch_decode(generated)
|
94 |
progress_dummy_output = ''
|
95 |
return ([progress_dummy_output] +
|
@@ -223,7 +224,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
|
|
223 |
for i in range(MAX_PROMPT_TOKENS):
|
224 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
225 |
tokens_container.append(btn)
|
226 |
-
|
227 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
228 |
|
229 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble',
|
@@ -255,7 +256,8 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
|
|
255 |
for i, btn in enumerate(tokens_container):
|
256 |
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
257 |
num_tokens, do_sample, temperature,
|
258 |
-
top_k, top_p, repetition_penalty, length_penalty
|
|
|
259 |
], [progress_dummy, *interpretation_bubbles])
|
260 |
|
261 |
original_prompt_btn.click(get_hidden_states,
|
|
|
13 |
|
14 |
## info
|
15 |
# Registry of datasets offered in the demo. Each entry gives a display name,
# the Hugging Face Hub repo id, and which column supplies the prompt text.
# An optional 'filter' callable selects usable rows (here: only rows whose
# 'label' field equals 1 in the counterfact dataset).
# NOTE(review): 'subject+predicate' looks like a composite column spec handled
# by downstream loading code — confirm against the dataset-loading helper.
dataset_info = [{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
                {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b',
                 'text_col': 'subject+predicate',
                 'filter': lambda x: x['label'] == 1},
                ]
|
19 |
|
|
|
66 |
|
67 |
|
68 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
69 |
+
temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
|
70 |
num_beams=1):
|
71 |
|
72 |
interpreted_vectors = global_state[:, i]
|
|
|
89 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
90 |
|
91 |
# generate the interpretations
|
92 |
+
generate = spaces.GPU(interpretation_prompt.generate) if use_gpu else interpretation_prompt.generate
|
93 |
+
generated = generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
94 |
generation_texts = tokenizer.batch_decode(generated)
|
95 |
progress_dummy_output = ''
|
96 |
return ([progress_dummy_output] +
|
|
|
224 |
for i in range(MAX_PROMPT_TOKENS):
|
225 |
btn = gr.Button('', visible=False, elem_classes=['token_btn'])
|
226 |
tokens_container.append(btn)
|
227 |
+
use_gpu = gr.Checkbox(value=True, label='Use GPU')
|
228 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
229 |
|
230 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble',
|
|
|
256 |
for i, btn in enumerate(tokens_container):
|
257 |
btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
|
258 |
num_tokens, do_sample, temperature,
|
259 |
+
top_k, top_p, repetition_penalty, length_penalty,
|
260 |
+
use_gpu
|
261 |
], [progress_dummy, *interpretation_bubbles])
|
262 |
|
263 |
original_prompt_btn.click(get_hidden_states,
|