Spaces: Running on Zero
Fix cached error
app.py CHANGED
@@ -50,10 +50,12 @@ happy_file_path = "assets/happy.jpg"
 def generate_activations(image):
     prompt = "<image>"
     inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)
-    global cached_tensor, topk_indices
+    global topk_indices
+
+    cached_list = []
 
     def hook(module: torch.nn.Module, _, outputs):
-        global cached_tensor, topk_indices
+        global topk_indices
         # Maybe unpack tuple outputs
         if isinstance(outputs, tuple):
             unpack_outputs = list(outputs)
@@ -72,7 +74,7 @@ def generate_activations(image):
         result = torch.zeros_like(latents)
         # results (bs, seq, num_latents)
         result.scatter_(-1, topk.indices, topk.values)
-        cached_tensor = result.detach().cpu()
+        cached_list.append(result.detach().cpu())
         topk_indices = (
             latents.squeeze(0).mean(dim=0).topk(k=100).indices.detach().cpu()
         )
@@ -91,10 +93,9 @@ def generate_activations(image):
     handle.remove()
 
     torch.cuda.empty_cache()
-    return topk_indices
+    return topk_indices, cached_list[0]
 
-def visualize_activations(image, feature_num):
-    global cached_tensor
+def visualize_activations(image, feature_num, cached_tensor):
     base_img_tokens = 576
     patch_size = 24
     # Using Cached tensor
@@ -191,6 +192,7 @@ def generate_with_clamp(feature_idx, feature_strength, text, image, chat_history
 
 
 with gr.Blocks() as demo:
+    cached_tensor = gr.State()
     gr.Markdown(
         """
         # Large Multi-modal Models Can Interpret Features in Large Multi-modal Models
@@ -210,12 +212,12 @@ with gr.Blocks() as demo:
         with gr.Row():
            clear_btn = gr.ClearButton([image, topk_features], value="Clear")
            submit_btn = gr.Button("Submit", variant="primary")
-        submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features])
+        submit_btn.click(generate_activations, inputs=[image], outputs=[topk_features, cached_tensor])
         with gr.Column():
            output = gr.Image(label="Activation Visualization")
            feature_num = gr.Slider(1, 131072, 1, 1, label="Feature Number", interactive=True)
            visualize_btn = gr.Button("Visualize", variant="primary")
-        visualize_btn.click(visualize_activations, inputs=[image, feature_num], outputs=[output])
+        visualize_btn.click(visualize_activations, inputs=[image, feature_num, cached_tensor], outputs=[output])
 
         dummy_text = gr.Textbox(visible=False, label="Explanation")
         gr.Examples(
@@ -261,7 +263,6 @@ with gr.Blocks() as demo:
 
 
 if __name__ == "__main__":
-    cached_tensor = None
     tokenizer = AutoTokenizer.from_pretrained("llava-hf/llama3-llava-next-8b-hf")
     sae = load_single_sae("lmms-lab/llama3-llava-next-8b-hf-sae-131k", "model.layers.24")
     model, processor = maybe_load_llava_model(
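In short, the commit stops caching the SAE activation tensor in a module-level global: `generate_activations` now returns the tensor, the app stores it in a per-session `gr.State`, and `visualize_activations` receives it as an extra input. Module-level globals are a poor fit here, likely because they are shared across user sessions and, on ZeroGPU Spaces, GPU handlers may run in separate worker processes where global mutations do not persist. The sketch below illustrates that pattern in isolation; it is an assumption-laden toy example, not the Space's code, and the names `compute`, `reuse`, and the textbox widgets are placeholders.

```python
# Minimal sketch of the pattern adopted in this commit: the first handler
# returns the expensive result, Gradio stores it in a per-session gr.State,
# and the second handler reads it back — no module-level globals involved.
# `compute` and `reuse` are illustrative names, not from the Space.
import gradio as gr


def compute(x):
    expensive_result = x.upper()  # stand-in for the model forward pass + SAE
    # The second return value is routed into the State component below.
    return expensive_result, expensive_result


def reuse(x, cached):
    # `cached` arrives from the State that `compute` populated earlier.
    return f"input={x}, cached={cached}"


with gr.Blocks() as demo:
    cached = gr.State()  # per-session storage, replaces the global cache
    text = gr.Textbox(label="Input")
    first = gr.Button("Compute")
    second = gr.Button("Reuse cache")
    out = gr.Textbox(label="Output")

    first.click(compute, inputs=[text], outputs=[out, cached])
    second.click(reuse, inputs=[text, cached], outputs=[out])

demo.launch()
```

The second return value of `compute` lands in the `cached` State component, which Gradio keeps per browser session, so the "Reuse cache" handler can consume it without any shared global state — the same wiring the commit applies to `topk_features` and `cached_tensor`.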