dar-tau commited on
Commit
63981db
1 Parent(s): 7f2e668

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -13,7 +13,7 @@ MAX_PROMPT_TOKENS = 30
13
 
14
  ## info
15
  dataset_info = [{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
16
- {'name': 'Fact Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
17
  'filter': lambda x: x['label'] == 1},
18
  ]
19
 
@@ -66,7 +66,7 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
66
 
67
 
68
  def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
69
- temperature, top_k, top_p, repetition_penalty, length_penalty, i,
70
  num_beams=1):
71
 
72
  interpreted_vectors = global_state[:, i]
@@ -89,7 +89,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
89
  interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
90
 
91
  # generate the interpretations
92
- generated = interpretation_prompt.generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
 
93
  generation_texts = tokenizer.batch_decode(generated)
94
  progress_dummy_output = ''
95
  return ([progress_dummy_output] +
@@ -223,7 +224,7 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
223
  for i in range(MAX_PROMPT_TOKENS):
224
  btn = gr.Button('', visible=False, elem_classes=['token_btn'])
225
  tokens_container.append(btn)
226
-
227
  progress_dummy = gr.Markdown('', elem_id='progress_dummy')
228
 
229
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble',
@@ -255,7 +256,8 @@ with gr.Blocks(theme=gr.themes.Default(), css=css) as demo:
255
  for i, btn in enumerate(tokens_container):
256
  btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
257
  num_tokens, do_sample, temperature,
258
- top_k, top_p, repetition_penalty, length_penalty
 
259
  ], [progress_dummy, *interpretation_bubbles])
260
 
261
  original_prompt_btn.click(get_hidden_states,
 
13
 
14
  ## info
15
  dataset_info = [{'name': 'Commonsense', 'hf_repo': 'tau/commonsense_qa', 'text_col': 'question'},
16
+ {'name': 'Factual Recall', 'hf_repo': 'azhx/counterfact-filtered-gptj6b', 'text_col': 'subject+predicate',
17
  'filter': lambda x: x['label'] == 1},
18
  ]
19
 
 
66
 
67
 
68
  def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
69
+ temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
70
  num_beams=1):
71
 
72
  interpreted_vectors = global_state[:, i]
 
89
  interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
90
 
91
  # generate the interpretations
92
+ generate = spaces.GPU(interpretation_prompt.generate) if use_gpu else interpretation_prompt.generate
93
+ generated = generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
94
  generation_texts = tokenizer.batch_decode(generated)
95
  progress_dummy_output = ''
96
  return ([progress_dummy_output] +
 
224
  for i in range(MAX_PROMPT_TOKENS):
225
  btn = gr.Button('', visible=False, elem_classes=['token_btn'])
226
  tokens_container.append(btn)
227
+ use_gpu = gr.Checkbox(value=True, label='Use GPU')
228
  progress_dummy = gr.Markdown('', elem_id='progress_dummy')
229
 
230
  interpretation_bubbles = [gr.Textbox('', container=False, visible=False, elem_classes=['bubble',
 
256
  for i, btn in enumerate(tokens_container):
257
  btn.click(partial(run_interpretation, i=i), [global_state, interpretation_prompt,
258
  num_tokens, do_sample, temperature,
259
+ top_k, top_p, repetition_penalty, length_penalty,
260
+ use_gpu
261
  ], [progress_dummy, *interpretation_bubbles])
262
 
263
  original_prompt_btn.click(get_hidden_states,