dar-tau commited on
Commit
2bb573c
1 Parent(s): 3f8ef2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -98,11 +98,12 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
98
 
99
  # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
100
  interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
101
- interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt, layers_format=global_state.layers_format)
102
 
103
  # generate the interpretations
104
  # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
105
- generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
 
106
  generation_texts = global_state.tokenizer.batch_decode(generated)
107
  progress_dummy_output = ''
108
  bubble_outputs = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {i}') for text in generation_texts]
 
98
 
99
  # create an InterpretationPrompt object from raw_interpretation_prompt (after putting it in the right template)
100
  interpretation_prompt = global_state.interpretation_prompt_template.format(prompt=raw_interpretation_prompt, repeat=5)
101
+ interpretation_prompt = InterpretationPrompt(global_state.tokenizer, interpretation_prompt)
102
 
103
  # generate the interpretations
104
  # generate = generate_interpretation_gpu if use_gpu else lambda interpretation_prompt, *args, **kwargs: interpretation_prompt.generate(*args, **kwargs)
105
+ generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, layers_format=global_state.layers_format, k=3,
106
+ **generation_kwargs)
107
  generation_texts = global_state.tokenizer.batch_decode(generated)
108
  progress_dummy_output = ''
109
  bubble_outputs = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {i}') for text in generation_texts]