dar-tau committed
Commit 997caf4 · verified · 1 Parent(s): df803aa

Update app.py

Files changed (1)
app.py +4 -3
app.py CHANGED
@@ -66,8 +66,8 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
 
 
 @spaces.GPU
-def generate_interpretation_gpu(interpret_prompt, **kwargs):
-    return interpret_prompt.generate(**kwargs)
+def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
+    return interpret_prompt.generate(*args, **kwargs)
 
 
 def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
@@ -94,7 +94,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
     interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
 
     # generate the interpretations
-    generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, **kwargs: interpretation_prompt.generate(**kwargs)
+    generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, \
+        **kwargs: interpretation_prompt.generate(*args, **kwargs)
     generated = generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs)
     generation_texts = tokenizer.batch_decode(generated)
     progress_dummy_output = ''
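
Why the change: run_interpretation invokes the wrapper positionally, as generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs), but both old wrappers accepted only **kwargs after interpret_prompt, so this call raised a TypeError. Adding *args lets the @spaces.GPU-decorated wrapper and the CPU fallback lambda forward positional arguments through. Below is a minimal sketch of the pattern; the InterpretationPrompt stub is hypothetical and included only so the example runs standalone, while the two wrapper signatures come from this diff.

class InterpretationPrompt:
    # Hypothetical stand-in for the app's real class, assumed here only
    # so the example is self-contained.
    def generate(self, model, substitutions, k=1, **generation_kwargs):
        return []  # the real method generates text from substituted hidden states

# Before: only **kwargs was forwarded, so a positional call like
# wrapper(prompt, model, {0: vectors}, k=3) raised a TypeError.
def generate_interpretation_gpu_old(interpret_prompt, **kwargs):
    return interpret_prompt.generate(**kwargs)

# After: *args forwards the positional arguments too, so the GPU wrapper
# and the CPU fallback accept the same call shape.
def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
    return interpret_prompt.generate(*args, **kwargs)

prompt = InterpretationPrompt()
generate_interpretation_gpu(prompt, "model", {0: "vectors"}, k=3)  # now accepted

A named top-level function is kept for the GPU path (rather than reusing the lambda) because the @spaces.GPU decorator must wrap the specific function that runs on the Space's GPU worker.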