Update app.py
Browse files
app.py
CHANGED
@@ -13,6 +13,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer, AutoModelForCausa
|
|
13 |
from interpret import InterpretationPrompt
|
14 |
|
15 |
MAX_PROMPT_TOKENS = 60
|
|
|
16 |
|
17 |
|
18 |
## info
|
@@ -102,7 +103,7 @@ def get_hidden_states(raw_original_prompt):
|
|
102 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
103 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
104 |
progress_dummy_output = ''
|
105 |
-
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(
|
106 |
global_state.hidden_states = hidden_states
|
107 |
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
108 |
|
@@ -136,9 +137,9 @@ def run_interpretation(raw_interpretation_prompt, max_new_tokens, do_sample,
|
|
136 |
generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
137 |
generation_texts = tokenizer.batch_decode(generated)
|
138 |
progress_dummy_output = ''
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
|
143 |
|
144 |
## main
|
@@ -235,7 +236,7 @@ with gr.Blocks(theme=gr.themes.Default(), css='styles.css') as demo:
|
|
235 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
236 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
|
237 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
238 |
-
) for i in range(
|
239 |
|
240 |
|
241 |
# event listeners
|
|
|
13 |
from interpret import InterpretationPrompt
|
14 |
|
15 |
MAX_PROMPT_TOKENS = 60
|
16 |
+
MAX_NUM_LAYERS = 50
|
17 |
|
18 |
|
19 |
## info
|
|
|
103 |
token_btns = ([gr.Button(token, visible=True) for token in tokens]
|
104 |
+ [gr.Button('', visible=False) for _ in range(MAX_PROMPT_TOKENS - len(tokens))])
|
105 |
progress_dummy_output = ''
|
106 |
+
invisible_bubbles = [gr.Textbox('', visible=False) for i in range(MAX_NUM_LAYERS)]
|
107 |
global_state.hidden_states = hidden_states
|
108 |
return [progress_dummy_output, *token_btns, *invisible_bubbles]
|
109 |
|
|
|
137 |
generated = interpretation_prompt.generate(global_state.model, {0: interpreted_vectors}, k=3, **generation_kwargs)
|
138 |
generation_texts = tokenizer.batch_decode(generated)
|
139 |
progress_dummy_output = ''
|
140 |
+
bubble_outputs = [gr.Textbox(text.replace('\n', ' '), visible=True, container=False, label=f'Layer {i}') for text in generation_texts]
|
141 |
+
bubble_outputs += [gr.Textbox(visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
|
142 |
+
return [progress_dummy_output, *bubble_outputs]
|
143 |
|
144 |
|
145 |
## main
|
|
|
236 |
progress_dummy = gr.Markdown('', elem_id='progress_dummy')
|
237 |
interpretation_bubbles = [gr.Textbox('', container=False, visible=False,
|
238 |
elem_classes=['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble']
|
239 |
+
) for i in range(MAX_NUM_LAYERS)]
|
240 |
|
241 |
|
242 |
# event listeners
|