kcz358 committed
Commit d6910a2
Parent: 18b266d

Allow showing auto interp explanations

Files changed (1)
  1. app.py +22 -4
app.py CHANGED
@@ -5,6 +5,9 @@ from sae_auto_interp.features.features import upsample_mask
 import torch
 from transformers import AutoTokenizer
 from PIL import Image
+from datasets import load_dataset
+from tqdm import tqdm
+import pandas as pd
 import spaces
 
 CITATION_BUTTON_TEXT = """
@@ -76,7 +79,7 @@ def generate_activations(image):
         result.scatter_(-1, topk.indices, topk.values)
         cached_list.append(result.detach().cpu())
         topk_indices = (
-            latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
+            latents.squeeze(0).mean(dim=0).topk(k=200).indices.detach().cpu()
         )
 
     handles = [hooked_module.register_forward_hook(hook)]
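For context on the k change above: the hook averages each SAE latent over the image-token positions and keeps the indices of the largest means, now 200 of them instead of 100. A minimal, self-contained sketch of that selection step, with a random tensor standing in for the SAE latents and a deliberately small feature width:

import torch

# Dummy stand-in for the SAE latents produced inside the forward hook:
# (batch=1, image_tokens=576, features). The real feature width is much larger;
# a small width is used here only to keep the example light.
latents = torch.rand(1, 576, 4096)

# Average each feature's activation over the token axis, then keep the k largest.
mean_acts = latents.squeeze(0).mean(dim=0)          # shape: (num_features,)
topk_indices = mean_acts.topk(k=200).indices.cpu()  # feature ids, most active first

print(topk_indices[:10])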
@@ -91,9 +94,14 @@ def generate_activations(image):
     finally:
         for handle in handles:
             handle.remove()
+    examples = []
+    for indice in topk_indices:
+        if indice <= 5000:
+            examples.append([ f"model.layers.24_feature{indice.item()}",explanations[f"model.layers.24_feature{indice.item()}"]])
+
 
     torch.cuda.empty_cache()
-    return topk_indices, cached_list[0]
+    return topk_indices, cached_list[0], examples
 
 def visualize_activations(image, feature_num, cached_tensor):
     base_img_tokens = 576
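The new `examples` block pairs each top-k feature with its auto-interp explanation, and the `indice <= 5000` guard exists because only the first 5k features have entries in the explanations table loaded at startup. A small illustrative sketch of the same lookup pattern, written with `dict.get` so unknown features are simply skipped; the dict contents and index list here are invented for the example:

# Hypothetical stand-ins for the real objects built elsewhere in app.py.
explanations = {
    "model.layers.24_feature3": "activates on dog faces",
    "model.layers.24_feature42": "activates on text in images",
}
topk_indices = [3, 42, 99999]  # pretend output of the top-k step

examples = []
for idx in topk_indices:
    key = f"model.layers.24_feature{idx}"
    explanation = explanations.get(key)  # None if this feature has no explanation
    if explanation is not None:
        examples.append([key, explanation])

print(examples)  # rows for a two-column table: [feature, explanation]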
@@ -208,11 +216,12 @@ with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column():
             image = gr.Image(type="pil", interactive=True, label="Sample Image")
-            topk_features = gr.Textbox(value=topk_indices, placeholder="Top 100 Features", label="Top 100 Features")
+            topk_features = gr.Textbox(value=topk_indices, placeholder="Top 200 Features", label="Top 100 Features", max_lines=5)
+            known_explanation = gr.DataFrame(headers=["Feature", "Auto Interp Explanation"], label="Auto Interp Explanations")
             with gr.Row():
                 clear_btn = gr.ClearButton([image, topk_features], value="Clear")
                 submit_btn = gr.Button("Submit", variant="primary")
-            submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor])
+            submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor, known_explanation])
         with gr.Column():
             output = gr.Image(label="Activation Visualization")
             feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
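For the UI wiring: `gr.DataFrame` can be used directly as an event output, with the handler returning a list of rows that match the declared headers. A stripped-down sketch of that pattern, where a dummy handler stands in for `generate_activations` and the component arguments mirror the diff:

import gradio as gr

def fake_generate(_image):
    # Stand-in for generate_activations: return rows for the explanations table.
    return [
        ["model.layers.24_feature3", "activates on dog faces"],
        ["model.layers.24_feature42", "activates on text in images"],
    ]

with gr.Blocks() as demo:
    image = gr.Image(type="pil", label="Sample Image")
    known_explanation = gr.DataFrame(
        headers=["Feature", "Auto Interp Explanation"],
        label="Auto Interp Explanations",
    )
    submit_btn = gr.Button("Submit", variant="primary")
    # The handler's return value fills the DataFrame component.
    submit_btn.click(fake_generate, inputs=[image], outputs=[known_explanation])

if __name__ == "__main__":
    demo.launch()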
@@ -272,5 +281,14 @@ if __name__ == "__main__":
         hf_token=None
     )
     hooked_module = model.language_model.get_submodule("model.layers.24")
+    dataset = load_dataset("lmms-lab/llava-sae-explanations-5k", "legacy", split="test")
+    dataset = dataset.remove_columns(["top1", "top2", "top3", "top4", "top5"])
+    print("Loading Explanation")
+    explanations = {}
+    pbar = tqdm(total=len(dataset), desc="Loading Explanation")
+    for da in dataset:
+        explanations[da["feature"]] = da["explanations"]
+        pbar.update(1)
+    pbar.close()
 
     demo.launch()
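The startup block above flattens the `lmms-lab/llava-sae-explanations-5k` dataset into a `feature -> explanation` dict with a tqdm loop. Since the commit also imports pandas, the same dict could be built via `Dataset.to_pandas`; a hedged alternative sketch, with the column names taken from the diff and everything else assumed:

from datasets import load_dataset

# Load the explanation split used by the demo.
dataset = load_dataset("lmms-lab/llava-sae-explanations-5k", "legacy", split="test")

# Convert to a pandas frame and build the feature -> explanation lookup in one pass.
df = dataset.to_pandas()
explanations = dict(zip(df["feature"], df["explanations"]))

print(len(explanations), "explanations loaded")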
 