Spaces:
Runtime error
Runtime error
File size: 3,077 Bytes
aeb9733 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
import PIL
import ast
import cv2
import json
import torch
import pickle
import torchvision
import numpy as np
import gradio as gr
from PIL import Image
from typing import Tuple, Dict
import matplotlib.pyplot as plt
from timeit import default_timer as timer
from torchvision import datasets, transforms
import vision_transformer as vits
# ---------------------------------------------------------------------------
# Model / visualization configuration
# ---------------------------------------------------------------------------
arch = "vit_small"  # name of the model factory looked up in `vits.__dict__`
mode = "simpool"  # pooling mode forwarded to the model constructor
gamma = None  # forwarded to the constructor; checkpoint file name says "no_gamma"
patch_size = 16
input_size = 224  # images are resized to input_size x input_size before inference
num_classes = 0  # forwarded to the constructor; presumably 0 = no classifier head — confirm
checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
checkpoint_key = "teacher"  # entry of the checkpoint dict holding the weights to load
cm = plt.get_cmap('viridis')  # colormap applied to the normalized attention map
attn_map_size = 224  # side length (pixels) of the rendered attention image
width_display = 300
height_display = 300
example_dir = "examples/"


def _list_examples(directory):
    """Return Gradio example rows (``[[path], ...]``) for the files in *directory*.

    Entries are sorted so the example gallery keeps a stable order across
    restarts (``os.listdir`` order is arbitrary). Hidden files (e.g.
    ``.gitkeep``) and sub-directories are skipped because Gradio would try to
    open them as images. A missing directory yields no examples instead of
    crashing the app at import time.
    """
    if not os.path.isdir(directory):
        return []
    return [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


example_list = _list_examples(example_dir)
# ---------------------------------------------------------------------------
# Load model
# ---------------------------------------------------------------------------
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)

# The model is never moved off the CPU, so map every stored tensor to the CPU
# while loading; without this a checkpoint saved from a GPU run cannot be
# deserialized on a CPU-only host.
state_dict = torch.load(checkpoint, map_location="cpu")
state_dict = state_dict[checkpoint_key]


def _strip_prefix(key, prefix):
    """Return *key* without a leading *prefix* (unchanged if it is absent)."""
    return key[len(prefix):] if key.startswith(prefix) else key


# Remove wrapper prefixes so checkpoint keys match the bare backbone's names.
# Only a *leading* prefix is stripped; str.replace would also mangle keys that
# merely contain "module." / "backbone." somewhere in the middle.
state_dict = {
    _strip_prefix(_strip_prefix(k, "module."), "backbone."): v
    for k, v in state_dict.items()
}
# Keep only entries the backbone defines; anything else in the checkpoint is
# discarded. The key set is built once instead of re-materializing the model's
# state dict for every checkpoint entry.
model_keys = set(model.state_dict())
state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
# strict=True still raises if the checkpoint lacks any backbone parameter.
msg = model.load_state_dict(state_dict, strict=True)
print(f"Loaded {len(state_dict)} tensors from {checkpoint} [{checkpoint_key!r}]: {msg}")
model.eval()
# Preprocessing applied to every input image: bicubic resize to a square
# input_size x input_size image, conversion to a float tensor, and per-channel
# normalization with the usual ImageNet mean/std values.
data_transforms = transforms.Compose([
    # Explicit enum instead of the integer code 3 (== BICUBIC); integer
    # interpolation codes are deprecated in torchvision.
    transforms.Resize(
        (input_size, input_size),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
def get_attention_map(img):
    """Render the SimPool attention map of ``model`` for a single image.

    Args:
        img: PIL image from the Gradio input component, or ``None`` when the
            user submits without providing an image.

    Returns:
        An RGB PIL image of ``attn_map_size`` x ``attn_map_size`` pixels with
        the min-max normalized, thresholded attention colored by ``cm``, or
        ``None`` if no image was given.
    """
    if img is None:
        return None

    # Normalize expects exactly 3 channels; grayscale ("L") or RGBA uploads
    # would otherwise make the transform fail.
    x = data_transforms(img.convert("RGB"))

    grid = input_size // patch_size  # patches per side of the square input
    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        attn = model.get_simpool_attention(x.unsqueeze(0))
    attn = attn.reshape(grid, grid)

    # Min-max rescale to [0, 1]. This step is invariant to any positive
    # rescaling of the raw scores, so a separate division by the sum is not
    # needed (assumes the attention weights sum to a positive value — confirm
    # for get_simpool_attention).
    lo, hi = attn.min(), attn.max()
    if hi > lo:
        attn = (attn - lo) / (hi - lo)
    else:
        # A constant map has no contrast; avoid 0/0 (NaN) and render it empty.
        attn = torch.zeros_like(attn)
    # Zero out weak responses (values <= 0.1) so the salient region stands out.
    attn = torch.threshold(attn, 0.1, 0)

    # The colormap yields RGBA floats in [0, 1]; scale to uint8 and drop alpha.
    colored = np.uint8(cm(attn.detach().cpu().numpy()) * 255)
    attn_img = Image.fromarray(colored).convert("RGB")
    # Nearest-neighbour upscaling keeps the patch grid visible.
    return attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
# Single-tab Gradio UI: one input image in, its SimPool attention map out.
attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Input Image")],
    outputs=gr.Image(
        type="pil",
        label="SimPool Attention Map",
        width=width_display,
        height=height_display,
    ),
    examples=example_list,
    title="Explore the Attention Maps of SimPool🔍",
    # The loaded weights are the "teacher" entry of a checkpoint whose file
    # name marks it as DINO training, i.e. self-supervised rather than
    # supervised, so the description says so.
    description=(
        "Upload or use one of the selected images to explore the intricate "
        "focus areas of a ViT-S model with SimPool, trained on ImageNet-1k "
        "with self-supervision (DINO)."
    ),
)

demo = gr.TabbedInterface(
    [attention_interface],
    ["Visualize Attention Maps"],
    title="SimPool Attention Map Visualizer 🌌",
)
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()