dar-tau committed
Commit d8c5a8d
Parent: 98858d4

Update app.py

Files changed (1)
app.py (+6, -1)
app.py CHANGED
@@ -65,6 +65,11 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
     return [progress_dummy_output, hidden_states, *token_btns]
 
 
+@spaces.GPU
+def generate_interpretation_gpu(interpret_prompt, **kwargs):
+    return interpret_prompt.generate(**kwargs)
+
+
 def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
                        temperature, top_k, top_p, repetition_penalty, length_penalty, use_gpu, i,
                        num_beams=1):
@@ -89,7 +94,7 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
     # generate the interpretations
-    generate = spaces.GPU(interpretation_prompt.generate) if use_gpu else interpretation_prompt.generate
+    generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, **kwargs: interpretation_prompt.generate(**kwargs)
     generated = generate(model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     progress_dummy_output = ''
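
The commit replaces the call-time wrap spaces.GPU(interpretation_prompt.generate) with a module-level function decorated by @spaces.GPU. On ZeroGPU Spaces the decorator is meant to mark a top-level function when the app is imported, so the runtime can register it and attach a GPU for the duration of each call; wrapping a bound method on the fly inside the request handler bypasses that registration. A minimal sketch of the pattern, with heavy_generate as a hypothetical stand-in for generate_interpretation_gpu:

import spaces  # Hugging Face ZeroGPU helper package

@spaces.GPU  # registered at import time; ZeroGPU attaches a GPU per call
def heavy_generate(prompt_obj, **kwargs):
    # hypothetical stand-in for the app's generate_interpretation_gpu
    return prompt_obj.generate(**kwargs)

def run(prompt_obj, use_gpu, **kwargs):
    # both branches expose the same call signature, GPU-decorated or plain CPU
    generate = heavy_generate if use_gpu else (lambda p, **kw: p.generate(**kw))
    return generate(prompt_obj, **kwargs)

Keeping the decorated function at module scope also means the use_gpu branch merely selects between two callables with one signature, instead of re-wrapping the method on every interpretation request.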