kcz358 committed
Commit 18b266d · 1 Parent(s): 88f8ab3

Fix cached error

Files changed (1)
  1. app.py  +10 −9
app.py CHANGED
@@ -50,10 +50,12 @@ happy_file_path = "assets/happy.jpg"
 def generate_activations(image):
     prompt = "<image>"
     inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
-    global cached_tensor, topk_indices
+    global topk_indices
+
+    cached_list = []
 
     def hook(module: torch.nn.Module, _, outputs):
-        global cached_tensor, topk_indices
+        global topk_indices
         # Maybe unpack tuple outputs
         if isinstance(outputs, tuple):
             unpack_outputs = list(outputs)
@@ -72,7 +74,7 @@ def generate_activations(image):
         result = torch.zeros_like(latents)
         # results (bs, seq, num_latents)
         result.scatter_(-1, topk.indices, topk.values)
-        cached_tensor = result.detach().cpu()
+        cached_list.append(result.detach().cpu())
         topk_indices = (
             latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
         )
@@ -91,10 +93,9 @@ def generate_activations(image):
     handle.remove()
 
     torch.cuda.empty_cache()
-    return topk_indices
+    return topk_indices, cached_list[0]
 
-def visualize_activations(image, feature_num):
-    global cached_tensor
+def visualize_activations(image, feature_num, cached_tensor):
     base_img_tokens = 576
     patch_size = 24
     # Using Cached tensor
@@ -191,6 +192,7 @@ def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history
 
 
 with gr.Blocks() as demo:
+    cached_tensor = gr.State()
     gr.Markdown(
         """
         # Large Multi-modal Models Can Interpret Features in Large Multi-modal Models
@@ -210,12 +212,12 @@ with gr.Blocks() as demo:
         with gr.Row():
            clear_btn = gr.ClearButton([image, topk_features], value="Clear")
            submit_btn = gr.Button("Submit", variant="primary")
-        submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])
+        submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor])
     with gr.Column():
         output = gr.Image(label="Activation Visualization")
         feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
         visualize_btn = gr.Button("Visualize", variant="primary")
-        visualize_btn.click(visualize_activations, inputs=[image, feature_num], outputs=[output])
+        visualize_btn.click(visualize_activations, inputs=[image, feature_num, cached_tensor], outputs=[output])
 
     dummy_text = gr.Textbox(visible=False, label="Explanation")
     gr.Examples(
@@ -261,7 +263,6 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
-    cached_tensor = None
    tokenizer = AutoTokenizer.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
    sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", "model.layers.24")
    model, processor = maybe_load_llava_model(
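
For reference, the pattern this commit moves to: instead of stashing the activation tensor in a module-level global (which is shared by every visitor of the Space and caused the cached-tensor error), generate_activations now returns the tensor into a gr.State component, and visualize_activations receives it back as an input, so each browser session keeps its own cache. Below is a minimal, self-contained sketch of that wiring; the toy functions, tensor shapes, and component names are illustrative assumptions, not the app's actual code.

import gradio as gr
import torch


def generate(image):
    # Stand-in for generate_activations: compute something expensive once and
    # return it as a second output so Gradio stores it in per-session state.
    cached = torch.rand(576, 16)  # fake "activations": (tokens, features)
    top_features = cached.mean(dim=0).topk(k=5).indices.tolist()
    return str(top_features), cached  # second value fills the gr.State below


def visualize(feature_num, cached):
    # Stand-in for visualize_activations: read this session's cached tensor.
    if cached is None:
        return "Run Submit first."
    idx = int(feature_num)  # sliders hand back floats
    return f"feature {idx}: max activation {cached[:, idx].max().item():.3f}"


with gr.Blocks() as demo:
    cached_tensor = gr.State()  # per-session storage instead of a global variable
    image = gr.Image(type="pil", label="Image")
    topk_features = gr.Textbox(label="Top Features")
    feature_num = gr.Slider(0, 15, step=1, label="Feature Number")
    output = gr.Textbox(label="Visualization")
    gr.Button("Submit").click(generate, inputs=[image], outputs=[topk_features, cached_tensor])
    gr.Button("Visualize").click(visualize, inputs=[feature_num, cached_tensor], outputs=[output])

if __name__ == "__main__":
    demo.launch()

Because the cached value travels through gr.State rather than a global, two users submitting different images no longer overwrite each other's activations, which is the failure mode the commit title refers to.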