hyesulim commited on
Commit
ee02e77
·
verified ·
1 Parent(s): 1198e0c

test: try to improve efficiency

Browse files
Files changed (1) hide show
  1. app.py +216 -29
app.py CHANGED
@@ -372,61 +372,233 @@ def load_all_data(image_root, pkl_root):
372
  return data_dict, sae_data_dict
373
 
374
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
376
  default_image_name = "christmas-imagenet"
 
 
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
 
379
  with gr.Blocks(
380
  theme=gr.themes.Citrus(),
381
  css="""
382
  .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
383
  .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
384
- """,
385
  ) as demo:
386
  with gr.Row():
387
  with gr.Column():
388
- # Left View: Image selection and click handling
389
  gr.Markdown("## Select input image and patch on the image")
 
 
390
  image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
391
  image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
392
 
393
- # Update image display when a new image is selected
 
 
 
394
  image_selector.change(
395
- fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
 
 
396
  )
 
 
397
  image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
398
 
399
  with gr.Column():
400
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
401
- model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
402
  model_selector = gr.Dropdown(
403
- choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
 
 
404
  )
405
- init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
406
  neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
407
 
 
 
 
 
 
 
408
  image_selector.change(
409
- fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
410
  )
411
  image_display.select(
412
- fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
 
 
 
 
 
413
  )
414
- model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
415
  model_selector.change(
416
- fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
 
 
417
  )
418
 
419
  with gr.Row():
420
  with gr.Column():
421
- radio_names = get_init_radio_options(default_image_name, model_options[0])
422
-
423
- feautre_idx = radio_names[0].split("-")[-1]
424
  markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
425
- init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
426
 
427
  gr.Markdown("### Localize SAE latent activation using CLIP")
428
  seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
429
- init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
430
  gr.Markdown("### Localize SAE latent activation using MaPLE")
431
  seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
432
 
@@ -434,12 +606,17 @@ with gr.Blocks(
434
  gr.Markdown("## Top activating SAE latent index")
435
 
436
  radio_choices = gr.Radio(
437
- choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
 
 
 
438
  )
 
439
  toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
440
 
441
  markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
442
 
 
443
  gr.Markdown("### ImageNet")
444
  top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
445
  act_value_1 = gr.Markdown(init_values[0])
@@ -452,18 +629,31 @@ with gr.Blocks(
452
  top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
453
  act_value_3 = gr.Markdown(init_values[2])
454
 
 
 
 
 
 
455
  image_display.select(
456
- fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
 
457
  )
458
-
459
  model_selector.change(
460
- fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
 
461
  )
462
-
463
  image_selector.select(
464
- fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
 
 
 
465
  )
466
 
 
467
  radio_choices.change(
468
  fn=update_markdown,
469
  inputs=[radio_choices],
@@ -471,6 +661,7 @@ with gr.Blocks(
471
  queue=True,
472
  )
473
 
 
474
  radio_choices.change(
475
  fn=show_activation_heatmap_clip,
476
  inputs=[image_selector, radio_choices, toggle_btn],
@@ -478,6 +669,7 @@ with gr.Blocks(
478
  queue=True,
479
  )
480
 
 
481
  radio_choices.change(
482
  fn=show_activation_heatmap_maple,
483
  inputs=[image_selector, radio_choices, model_selector],
@@ -485,13 +677,7 @@ with gr.Blocks(
485
  queue=True,
486
  )
487
 
488
- # toggle_btn.change(
489
- # fn=get_top_images,
490
- # inputs=[radio_choices, toggle_btn],
491
- # outputs=[top_image_1, top_image_2, top_image_3],
492
- # queue=True,
493
- # )
494
-
495
  toggle_btn.change(
496
  fn=show_activation_heatmap_clip,
497
  inputs=[image_selector, radio_choices, toggle_btn],
@@ -501,3 +687,4 @@ with gr.Blocks(
501
 
502
  # Launch the app
503
  demo.launch()
 
 
372
  return data_dict, sae_data_dict
373
 
374
 
375
+ # data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
376
+ # default_image_name = "christmas-imagenet"
377
+
378
+
379
+ # with gr.Blocks(
380
+ # theme=gr.themes.Citrus(),
381
+ # css="""
382
+ # .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
383
+ # .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
384
+ # """,
385
+ # ) as demo:
386
+ # with gr.Row():
387
+ # with gr.Column():
388
+ # # Left View: Image selection and click handling
389
+ # gr.Markdown("## Select input image and patch on the image")
390
+ # image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
391
+ # image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
392
+
393
+ # # Update image display when a new image is selected
394
+ # image_selector.change(
395
+ # fn=lambda img_name: data_dict[img_name]["image"], inputs=image_selector, outputs=image_display
396
+ # )
397
+ # image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
398
+
399
+ # with gr.Column():
400
+ # gr.Markdown("## SAE latent activations of CLIP and MaPLE")
401
+ # model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
402
+ # model_selector = gr.Dropdown(
403
+ # choices=model_options, value=model_options[0], label="Select adapted model (MaPLe)"
404
+ # )
405
+ # init_plot = plot_activation_distribution(None, default_image_name, model_options[0])
406
+ # neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
407
+
408
+ # image_selector.change(
409
+ # fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
410
+ # )
411
+ # image_display.select(
412
+ # fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
413
+ # )
414
+ # model_selector.change(fn=load_image, inputs=[image_selector], outputs=image_display)
415
+ # model_selector.change(
416
+ # fn=plot_activation_distribution, inputs=[image_selector, model_selector], outputs=neuron_plot
417
+ # )
418
+
419
+ # with gr.Row():
420
+ # with gr.Column():
421
+ # radio_names = get_init_radio_options(default_image_name, model_options[0])
422
+
423
+ # feautre_idx = radio_names[0].split("-")[-1]
424
+ # markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
425
+ # init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
426
+
427
+ # gr.Markdown("### Localize SAE latent activation using CLIP")
428
+ # seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
429
+ # init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], model_options[0])
430
+ # gr.Markdown("### Localize SAE latent activation using MaPLE")
431
+ # seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
432
+
433
+ # with gr.Column():
434
+ # gr.Markdown("## Top activating SAE latent index")
435
+
436
+ # radio_choices = gr.Radio(
437
+ # choices=radio_names, label="Top activating SAE latent", interactive=True, value=radio_names[0]
438
+ # )
439
+ # toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
440
+
441
+ # markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
442
+
443
+ # gr.Markdown("### ImageNet")
444
+ # top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
445
+ # act_value_1 = gr.Markdown(init_values[0])
446
+
447
+ # gr.Markdown("### ImageNet-Sketch")
448
+ # top_image_2 = gr.Image(value=init_tops[1], type="pil", label="ImageNet-Sketch", show_label=False)
449
+ # act_value_2 = gr.Markdown(init_values[1])
450
+
451
+ # gr.Markdown("### Caltech101")
452
+ # top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
453
+ # act_value_3 = gr.Markdown(init_values[2])
454
+
455
+ # image_display.select(
456
+ # fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
457
+ # )
458
+
459
+ # model_selector.change(
460
+ # fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
461
+ # )
462
+
463
+ # image_selector.select(
464
+ # fn=update_radio_options, inputs=[image_selector, model_selector], outputs=[radio_choices], queue=True
465
+ # )
466
+
467
+ # radio_choices.change(
468
+ # fn=update_markdown,
469
+ # inputs=[radio_choices],
470
+ # outputs=[markdown_display, markdown_display_2],
471
+ # queue=True,
472
+ # )
473
+
474
+ # radio_choices.change(
475
+ # fn=show_activation_heatmap_clip,
476
+ # inputs=[image_selector, radio_choices, toggle_btn],
477
+ # outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
478
+ # queue=True,
479
+ # )
480
+
481
+ # radio_choices.change(
482
+ # fn=show_activation_heatmap_maple,
483
+ # inputs=[image_selector, radio_choices, model_selector],
484
+ # outputs=[seg_mask_display_maple],
485
+ # queue=True,
486
+ # )
487
+
488
+ # # toggle_btn.change(
489
+ # # fn=get_top_images,
490
+ # # inputs=[radio_choices, toggle_btn],
491
+ # # outputs=[top_image_1, top_image_2, top_image_3],
492
+ # # queue=True,
493
+ # # )
494
+
495
+ # toggle_btn.change(
496
+ # fn=show_activation_heatmap_clip,
497
+ # inputs=[image_selector, radio_choices, toggle_btn],
498
+ # outputs=[seg_mask_display, top_image_1, top_image_2, top_image_3, act_value_1, act_value_2, act_value_3],
499
+ # queue=True,
500
+ # )
501
+
502
+ # # Launch the app
503
+ # demo.launch()
504
+
505
+ # Precompute all necessary data and store in caches before launching the Gradio app.
506
+
507
+ # Load data once at startup
508
  data_dict, sae_data_dict = load_all_data(image_root="./data/image", pkl_root=pkl_root)
509
  default_image_name = "christmas-imagenet"
510
+ model_options = [f"MaPLE-{dataset_name}" for dataset_name in DATASET_LIST]
511
+ default_model = model_options[0]
512
 
513
+ # Precompute activation distributions for all images/models to avoid repeated I/O.
514
+ activation_cache = {}
515
+ for img_name in data_dict.keys():
516
+ for mdl in ["CLIP"] + model_options:
517
+ activation_cache[(img_name, mdl)] = get_activation_distribution(img_name, mdl)
518
+
519
+ # Precompute initial radio options and top-neuron related info for default states.
520
+ radio_names = get_init_radio_options(default_image_name, default_model)
521
+ feautre_idx = radio_names[0].split("-")[-1]
522
+
523
+ # Precompute initial figures and mask overlays so they don't need to be recomputed on load.
524
+ init_plot = plot_activation_distribution(None, default_image_name, default_model)
525
+ init_seg, init_tops, init_values = show_activation_heatmap(default_image_name, radio_names[0], "CLIP")
526
+ init_seg_maple, _, _ = show_activation_heatmap(default_image_name, radio_names[0], default_model)
527
 
528
  with gr.Blocks(
529
  theme=gr.themes.Citrus(),
530
  css="""
531
  .image-row .gr-image { margin: 0 !important; padding: 0 !important; }
532
  .image-row img { width: auto; height: 50px; } /* Set a uniform height for all images */
533
+ """
534
  ) as demo:
535
  with gr.Row():
536
  with gr.Column():
 
537
  gr.Markdown("## Select input image and patch on the image")
538
+
539
+ # Instead of recomputing, just directly load from data_dict
540
  image_selector = gr.Dropdown(choices=list(data_dict.keys()), value=default_image_name, label="Select Image")
541
  image_display = gr.Image(value=data_dict[default_image_name]["image"], type="pil", interactive=True)
542
 
543
+ # When image changes, just display the corresponding image
544
+ def update_image_display(img_name):
545
+ return data_dict[img_name]["image"]
546
+
547
  image_selector.change(
548
+ fn=update_image_display,
549
+ inputs=image_selector,
550
+ outputs=image_display
551
  )
552
+
553
+ # Highlight selected grid cell
554
  image_display.select(fn=highlight_grid, inputs=[image_selector], outputs=[image_display])
555
 
556
  with gr.Column():
557
  gr.Markdown("## SAE latent activations of CLIP and MaPLE")
558
+
559
  model_selector = gr.Dropdown(
560
+ choices=model_options,
561
+ value=default_model,
562
+ label="Select adapted model (MaPLe)"
563
  )
564
+
565
  neuron_plot = gr.Plot(label="Neuron Activation", value=init_plot, show_label=False)
566
 
567
+ # Update plot based on image/model
568
+ def update_plot(img_name, model_name):
569
+ # Use precomputed activation distributions from activation_cache
570
+ # to create the figure. If figure creation is expensive, consider caching plots as well.
571
+ return plot_activation_distribution(None, img_name, model_name)
572
+
573
  image_selector.change(
574
+ fn=update_plot,
575
+ inputs=[image_selector, model_selector],
576
+ outputs=neuron_plot
577
  )
578
  image_display.select(
579
+ fn=update_plot,
580
+ inputs=[image_selector, model_selector],
581
+ outputs=neuron_plot
582
+ )
583
+ model_selector.change(
584
+ fn=lambda img_name: data_dict[img_name]["image"],
585
+ inputs=[image_selector],
586
+ outputs=image_display
587
  )
 
588
  model_selector.change(
589
+ fn=update_plot,
590
+ inputs=[image_selector, model_selector],
591
+ outputs=neuron_plot
592
  )
593
 
594
  with gr.Row():
595
  with gr.Column():
596
+ # Use previously precomputed segmentation masks and tops instead of recomputing on load
 
 
597
  markdown_display = gr.Markdown(f"## Segmentation mask for the selected SAE latent - {feautre_idx}")
 
598
 
599
  gr.Markdown("### Localize SAE latent activation using CLIP")
600
  seg_mask_display = gr.Image(value=init_seg, type="pil", show_label=False)
601
+
602
  gr.Markdown("### Localize SAE latent activation using MaPLE")
603
  seg_mask_display_maple = gr.Image(value=init_seg_maple, type="pil", show_label=False)
604
 
 
606
  gr.Markdown("## Top activating SAE latent index")
607
 
608
  radio_choices = gr.Radio(
609
+ choices=radio_names,
610
+ label="Top activating SAE latent",
611
+ interactive=True,
612
+ value=radio_names[0]
613
  )
614
+
615
  toggle_btn = gr.Checkbox(label="Show segmentation mask", value=False)
616
 
617
  markdown_display_2 = gr.Markdown(f"## Top reference images for the selected SAE latent - {feautre_idx}")
618
 
619
+ # Display precomputed top images and values
620
  gr.Markdown("### ImageNet")
621
  top_image_1 = gr.Image(value=init_tops[0], type="pil", label="ImageNet", show_label=False)
622
  act_value_1 = gr.Markdown(init_values[0])
 
629
  top_image_3 = gr.Image(value=init_tops[2], type="pil", label="Caltech101", show_label=False)
630
  act_value_3 = gr.Markdown(init_values[2])
631
 
632
+ # Update radio choices when image/model changes.
633
+ # If expensive, this could be cached as well.
634
+ def on_image_or_model_change(img_name, model_name):
635
+ return update_radio_options(None, img_name, model_name)
636
+
637
  image_display.select(
638
+ fn=on_image_or_model_change,
639
+ inputs=[image_selector, model_selector],
640
+ outputs=[radio_choices],
641
+ queue=True
642
  )
 
643
  model_selector.change(
644
+ fn=on_image_or_model_change,
645
+ inputs=[image_selector, model_selector],
646
+ outputs=[radio_choices],
647
+ queue=True
648
  )
 
649
  image_selector.select(
650
+ fn=on_image_or_model_change,
651
+ inputs=[image_selector, model_selector],
652
+ outputs=[radio_choices],
653
+ queue=True
654
  )
655
 
656
+ # Update markdown titles dynamically based on selected radio choice
657
  radio_choices.change(
658
  fn=update_markdown,
659
  inputs=[radio_choices],
 
661
  queue=True,
662
  )
663
 
664
+ # Show activation heatmap for CLIP
665
  radio_choices.change(
666
  fn=show_activation_heatmap_clip,
667
  inputs=[image_selector, radio_choices, toggle_btn],
 
669
  queue=True,
670
  )
671
 
672
+ # Show activation heatmap for MaPLE
673
  radio_choices.change(
674
  fn=show_activation_heatmap_maple,
675
  inputs=[image_selector, radio_choices, model_selector],
 
677
  queue=True,
678
  )
679
 
680
+ # Toggle segmentation mask
 
 
 
 
 
 
681
  toggle_btn.change(
682
  fn=show_activation_heatmap_clip,
683
  inputs=[image_selector, radio_choices, toggle_btn],
 
687
 
688
  # Launch the app
689
  demo.launch()
690
+