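"""Gradio demo for visualizing SimPool attention maps.

Loads a ViT-S/16 backbone with SimPool from a local DINO checkpoint (via the
accompanying `vision_transformer` module) and serves a single-tab interface
that renders the SimPool attention map of an uploaded or example image.
"""
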
import os

import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image

import torch
from torchvision import transforms

import vision_transformer as vits

# Model / checkpoint configuration: ViT-S/16 backbone with SimPool, DINO checkpoint
arch = "vit_small"
mode = "simpool"
gamma = None        # SimPool without the gamma parameter (matches the "no_gamma" checkpoint)
patch_size = 16
num_classes = 0     # backbone only, no classification head
checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
checkpoint_key = "teacher"  # DINO checkpoints store both "student" and "teacher" weights

# Visualization settings
cm = plt.get_cmap('viridis')   # colormap applied to the attention values
attn_map_size = 224            # side length (pixels) of the rendered attention map
width_display = 290            # display size of the output image in the interface
height_display = 290

# Example inputs for the interface; each example supplies an image path and the default resolution
example_dir = "examples/"
example_list = [[example_dir + example, 32] for example in os.listdir(example_dir)]

# Build the backbone and load the pretrained weights
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)
state_dict = torch.load(checkpoint, map_location="cpu")  # the demo runs on CPU
state_dict = state_dict[checkpoint_key]
# Strip DataParallel / DINO-wrapper prefixes and drop keys the backbone does not expect
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
msg = model.load_state_dict(state_dict, strict=True)

model.eval()

def get_attention_map(img, resolution=32):
    # `resolution` is the side length of the attention grid in patches; resize
    # the input so the backbone produces a resolution x resolution patch grid.
    input_size = resolution * patch_size
    data_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    x = data_transforms(img.convert("RGB"))

    with torch.no_grad():
        attn = model.get_simpool_attention(x[None, :, :, :])
    attn = attn.reshape(1, 1, input_size // patch_size, input_size // patch_size)

    # Normalize to [0, 1] and suppress weak attention values
    attn = attn / attn.sum()
    attn = attn.squeeze()
    attn = (attn - attn.min()) / (attn.max() - attn.min())
    attn = torch.threshold(attn, 0.1, 0)

    # Map the attention values through the colormap and upsample for display
    attn_img = Image.fromarray(np.uint8(cm(attn.numpy()) * 255)).convert('RGB')
    attn_img = attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
    return attn_img
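
# Optional local sanity check (a minimal sketch, not part of the original demo):
# assuming "examples/" contains at least one image, this renders one attention
# map to disk without launching the web interface. Uncomment to try it:
#
# if example_list:
#     _img = Image.open(example_list[0][0]).convert("RGB")
#     get_attention_map(_img, resolution=32).save("attn_preview.png")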

attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=[16, 32, 64, 128],
                    label="Attention Map Resolution", 
                    value=32)
    ],
    outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
    examples=example_list,
    title="Explore the Attention Maps of SimPool🔍",
    description="Upload an image, or pick one of the examples, to explore the focus areas of a ViT-S/16 model with SimPool, pretrained on ImageNet-1k with DINO self-supervision."
)

demo = gr.TabbedInterface([attention_interface],
                          ["Visualize Attention Maps"], title="SimPool Attention Map Visualizer 🌌")

if __name__ == "__main__":
    demo.launch(share=True)  # share=True additionally creates a temporary public Gradio link