Spaces:
Running
on
Zero
Running
on
Zero
Allow shown auto interp explanation
Browse files
app.py
CHANGED
@@ -5,6 +5,9 @@ from sae_auto_interp.features.features import upsample_mask
|
|
5 |
import torch
|
6 |
from transformers import AutoTokenizer
|
7 |
from PIL import Image
|
|
|
|
|
|
|
8 |
import spaces
|
9 |
|
10 |
CITATION_BUTTON_TEXT = """
|
@@ -76,7 +79,7 @@ def generate_activations(image):
|
|
76 |
result.scatter_(-1, topk.indices, topk.values)
|
77 |
cached_list.append(result.detach().cpu())
|
78 |
topk_indices = (
|
79 |
-
latents.squeeze(0).mean(dim=0).topk(k=
|
80 |
)
|
81 |
|
82 |
handles = [hooked_module.register_forward_hook(hook)]
|
@@ -91,9 +94,14 @@ def generate_activations(image):
|
|
91 |
finally:
|
92 |
for handle in handles:
|
93 |
handle.remove()
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
torch.cuda.empty_cache()
|
96 |
-
return topk_indices, cached_list[0]
|
97 |
|
98 |
def visualize_activations(image, feature_num, cached_tensor):
|
99 |
base_img_tokens = 576
|
@@ -208,11 +216,12 @@ with gr.Blocks() as demo:
|
|
208 |
with gr.Row():
|
209 |
with gr.Column():
|
210 |
image = gr.Image(type="pil", interactive=True, label="Sample Image")
|
211 |
-
topk_features = gr.Textbox(value=topk_indices, placeholder="Top
|
|
|
212 |
with gr.Row():
|
213 |
clear_btn = gr.ClearButton([image, topk_features], value="Clear")
|
214 |
submit_btn = gr.Button("Submit", variant="primary")
|
215 |
-
submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor])
|
216 |
with gr.Column():
|
217 |
output = gr.Image(label="Activation Visualization")
|
218 |
feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
|
@@ -272,5 +281,14 @@ if __name__ == "__main__":
|
|
272 |
hf_token=None
|
273 |
)
|
274 |
hooked_module = model.language_model.get_submodule("model.layers.24")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
|
276 |
demo.launch()
|
|
|
5 |
import torch
|
6 |
from transformers import AutoTokenizer
|
7 |
from PIL import Image
|
8 |
+
from datasets import load_dataset
|
9 |
+
from tqdm import tqdm
|
10 |
+
import pandas as pd
|
11 |
import spaces
|
12 |
|
13 |
CITATION_BUTTON_TEXT = """
|
|
|
79 |
result.scatter_(-1, topk.indices, topk.values)
|
80 |
cached_list.append(result.detach().cpu())
|
81 |
topk_indices = (
|
82 |
+
latents.squeeze(0).mean(dim=0).topk(k=200).indices.detach().cpu()
|
83 |
)
|
84 |
|
85 |
handles = [hooked_module.register_forward_hook(hook)]
|
|
|
94 |
finally:
|
95 |
for handle in handles:
|
96 |
handle.remove()
|
97 |
+
examples = []
|
98 |
+
for indice in topk_indices:
|
99 |
+
if indice <= 5000:
|
100 |
+
examples.append([ f"model.layers.24_feature{indice.item()}",explanations[f"model.layers.24_feature{indice.item()}"]])
|
101 |
+
|
102 |
|
103 |
torch.cuda.empty_cache()
|
104 |
+
return topk_indices, cached_list[0], examples
|
105 |
|
106 |
def visualize_activations(image, feature_num, cached_tensor):
|
107 |
base_img_tokens = 576
|
|
|
216 |
with gr.Row():
|
217 |
with gr.Column():
|
218 |
image = gr.Image(type="pil", interactive=True, label="Sample Image")
|
219 |
+
topk_features = gr.Textbox(value=topk_indices, placeholder="Top 200 Features", label="Top 100 Features", max_lines=5)
|
220 |
+
known_explanation = gr.DataFrame(headers=["Feature", "Auto Interp Explanation"], label="Auto Interp Explanations")
|
221 |
with gr.Row():
|
222 |
clear_btn = gr.ClearButton([image, topk_features], value="Clear")
|
223 |
submit_btn = gr.Button("Submit", variant="primary")
|
224 |
+
submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor, known_explanation])
|
225 |
with gr.Column():
|
226 |
output = gr.Image(label="Activation Visualization")
|
227 |
feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
|
|
|
281 |
hf_token=None
|
282 |
)
|
283 |
hooked_module = model.language_model.get_submodule("model.layers.24")
|
284 |
+
dataset = load_dataset("lmms-lab/llava-sae-explanations-5k", "legacy", split="test")
|
285 |
+
dataset = dataset.remove_columns(["top1", "top2", "top3", "top4", "top5"])
|
286 |
+
print("Loading Explanation")
|
287 |
+
explanations = {}
|
288 |
+
pbar = tqdm(total=len(dataset), desc="Loading Explanation")
|
289 |
+
for da in dataset:
|
290 |
+
explanations[da["feature"]] = da["explanations"]
|
291 |
+
pbar.update(1)
|
292 |
+
pbar.close()
|
293 |
|
294 |
demo.launch()
|