Update app.py
Browse files
app.py
CHANGED
@@ -131,13 +131,13 @@ def run_interpretation(raw_original_prompt, raw_interpretation_prompt, max_new_t
|
|
131 |
generation_texts = tokenizer.batch_decode(generated)
|
132 |
|
133 |
# create GUI output
|
134 |
-
|
135 |
-
important_idxs
|
136 |
progress_dummy_output = ''
|
137 |
elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
|
138 |
([] if i in important_idxs else ['faded_bubble']) for i in range(len(generation_texts))]
|
139 |
bubble_outputs = [gr.Textbox(text.replace('\n', ' '), show_label=True, visible=True,
|
140 |
-
container=
|
141 |
for i, text in enumerate(generation_texts)]
|
142 |
bubble_outputs += [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
|
143 |
return [progress_dummy_output, *bubble_outputs]
|
|
|
131 |
generation_texts = tokenizer.batch_decode(generated)
|
132 |
|
133 |
# create GUI output
|
134 |
+
important_idxs = 1 + interpreted_vectors.diff(dim=0).topk(k=int(np.ceil(0.2 * len(generation_texts))), dim=0).indices.cpu().numpy()
|
135 |
+
print(f'{important_idxs=}')
|
136 |
progress_dummy_output = ''
|
137 |
elem_classes = [['bubble', 'even_bubble' if i % 2 == 0 else 'odd_bubble'] +
|
138 |
([] if i in important_idxs else ['faded_bubble']) for i in range(len(generation_texts))]
|
139 |
bubble_outputs = [gr.Textbox(text.replace('\n', ' '), show_label=True, visible=True,
|
140 |
+
container=True, label=f'Layer {i}', elem_classes=elem_classes[i])
|
141 |
for i, text in enumerate(generation_texts)]
|
142 |
bubble_outputs += [gr.Textbox('', visible=False) for _ in range(MAX_NUM_LAYERS - len(bubble_outputs))]
|
143 |
return [progress_dummy_output, *bubble_outputs]
|