Update app.py
Browse files
app.py
CHANGED
@@ -66,8 +66,8 @@ def get_hidden_states(raw_original_prompt, progress=gr.Progress()):
|
|
66 |
|
67 |
|
68 |
@spaces.GPU
|
69 |
-
def generate_interpretation_gpu(interpret_prompt, **kwargs):
|
70 |
-
return interpret_prompt.generate(**kwargs)
|
71 |
|
72 |
|
73 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
@@ -94,7 +94,8 @@ def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens,
|
|
94 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
95 |
|
96 |
# generate the interpretations
|
97 |
-
generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt,
|
|
|
98 |
generated = generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
99 |
generation_texts = tokenizer.batch_decode(generated)
|
100 |
progress_dummy_output = ''
|
|
|
66 |
|
67 |
|
68 |
@spaces.GPU
|
69 |
+
def generate_interpretation_gpu(interpret_prompt, *args, **kwargs):
|
70 |
+
return interpret_prompt.generate(*args, **kwargs)
|
71 |
|
72 |
|
73 |
def run_interpretation(global_state, raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
|
94 |
interpretation_prompt = InterpretationPrompt(tokenizer, interpretation_prompt)
|
95 |
|
96 |
# generate the interpretations
|
97 |
+
generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args,
|
98 |
+
**kwargs: interpretation_prompt.generate(*args, **kwargs)
|
99 |
generated = generate(interpretation_prompt, model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
100 |
generation_texts = tokenizer.batch_decode(generated)
|
101 |
progress_dummy_output = ''
|