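"""Gradio demo for visualizing SimPool attention maps.

Loads a pretrained ViT-S/16 backbone with SimPool and renders the SimPool
attention over an uploaded image as a viridis-colored heatmap. Running the
script starts the demo; it assumes the SimPool repository's
`vision_transformer.py` module, the checkpoint file, and the `examples/`
directory are available next to this script.
"""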
import os
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import PIL
from PIL import Image
import torch
import torchvision
from torchvision import datasets, transforms
import vision_transformer as vits
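# `vision_transformer` is expected to come from the SimPool repository: it must
# expose the `mode` / `gamma` constructor arguments and the
# `get_simpool_attention` method used below.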
arch = "vit_small"          # ViT-S backbone defined in vision_transformer.py
mode = "simpool"            # use SimPool pooling instead of the [CLS] token
gamma = None                # SimPool gamma parameter (None = "no gamma" variant)
patch_size = 16
num_classes = 0             # no classification head; backbone features only
checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
checkpoint_key = "teacher"  # DINO checkpoints store both "student" and "teacher" weights
cm = plt.get_cmap('viridis')  # colormap used to render the attention map
attn_map_size = 224           # side length (pixels) of the rendered attention image
width_display = 290           # display size of the output image in the UI
height_display = 290
example_dir = "examples/"
example_list = [[example_dir + example] for example in os.listdir(example_dir)]
#example_list = "n03017168_54500.JPEG"
# Load the SimPool ViT-S/16 backbone and its pretrained weights
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)
state_dict = torch.load(checkpoint, map_location="cpu")  # load on CPU so the demo also runs without a GPU
state_dict = state_dict[checkpoint_key]
# Strip DataParallel / DINO wrapper prefixes and drop keys the backbone does not expect
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
msg = model.load_state_dict(state_dict, strict=True)
model.eval()
def get_attention_map(img, resolution=32):
    """Return the SimPool attention map of `img` as a color-mapped PIL image."""
    input_size = resolution * 14  # side length (pixels) of the resized input image
    data_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size), interpolation=3),  # 3 = bicubic
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),  # ImageNet statistics
    ])
    x = data_transforms(img)
    with torch.no_grad():
        attn = model.get_simpool_attention(x[None, :, :, :])  # add a batch dimension
    # Reshape the flat attention vector to the patch grid
    attn = attn.reshape(1, 1, input_size // patch_size, input_size // patch_size)
    attn = attn / attn.sum()
    attn = attn.squeeze()
    # Min-max normalize to [0, 1] and zero out weak responses below 0.1
    attn = (attn - attn.min()) / (attn.max() - attn.min())
    attn = torch.threshold(attn, 0.1, 0)
    # Apply the viridis colormap and resize to the display resolution
    attn_img = Image.fromarray(np.uint8(cm(attn.detach().numpy()) * 255)).convert('RGB')
    attn_img = attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
    return attn_img
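# Example usage outside the UI (illustrative; any image from examples/ works):
#   attn_img = get_attention_map(Image.open("examples/n03017168_54500.JPEG"), resolution=32)
#   attn_img.save("simpool_attention.png")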
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Dropdown(choices=[16, 32, 64, 128],
                    label="Attention Map Resolution",
                    value=32),
    ],
    outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
    examples=example_list,
    title="Explore the Attention Maps of SimPool🔍",
    description="Upload or use one of the selected images to explore the intricate focus areas of a ViT-S model with SimPool, trained on ImageNet-1k, under supervision."
)
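# A TabbedInterface wraps the single interface in its own tab; additional
# visualizations could be added as extra tabs later.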
demo = gr.TabbedInterface([attention_interface],
                          ["Visualize Attention Maps"],
                          title="SimPool Attention Map Visualizer 🌌")
if __name__ == "__main__":
    demo.launch(share=True)  # share=True also requests a temporary public Gradio link