import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

import vision_transformer as vits

# Model / checkpoint configuration
arch = "vit_small"
mode = "simpool"
gamma = None
patch_size = 16
input_size = 224
num_classes = 0  # feature extractor only; no classification head
checkpoint = "checkpoints/vits_dino_simpool_no_gamma_ep100.pth"
checkpoint_key = "teacher"

# Visualization configuration
cm = plt.get_cmap("viridis")
attn_map_size = 224
width_display = 300
height_display = 300

example_dir = "examples/"
example_list = [[example_dir + example] for example in os.listdir(example_dir)]

# Load model
model = vits.__dict__[arch](
    mode=mode,
    gamma=gamma,
    patch_size=patch_size,
    num_classes=num_classes,
)
state_dict = torch.load(checkpoint, map_location="cpu")
state_dict = state_dict[checkpoint_key]
# Strip DataParallel / DINO prefixes, then keep only keys the model expects.
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k: v for k, v in state_dict.items() if k in model.state_dict()}
msg = model.load_state_dict(state_dict, strict=True)
print(f"Loaded '{checkpoint_key}' weights from {checkpoint}: {msg}")
model.eval()

# Define transformations (bicubic resize, ImageNet normalization)
data_transforms = transforms.Compose([
    transforms.Resize((input_size, input_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


@torch.no_grad()
def get_attention_map(img):
    """Compute and colorize the SimPool attention map for a PIL image."""
    x = data_transforms(img)
    attn = model.get_simpool_attention(x[None, :, :, :])
    # Reshape the flat patch attention into its (input_size / patch_size)^2 grid.
    attn = attn.reshape(1, 1, input_size // patch_size, input_size // patch_size)
    attn = attn / attn.sum()
    attn = attn.squeeze()
    # Min-max normalize to [0, 1], then zero out weak responses.
    attn = (attn - attn.min()) / (attn.max() - attn.min())
    attn = torch.threshold(attn, 0.1, 0)
    # Colorize with viridis and upsample (nearest neighbor) for display.
    attn_img = Image.fromarray(np.uint8(cm(attn.numpy()) * 255)).convert("RGB")
    attn_img = attn_img.resize((attn_map_size, attn_map_size), resample=Image.NEAREST)
    return attn_img


attention_interface = gr.Interface(
    fn=get_attention_map,
    inputs=[gr.Image(type="pil", label="Input Image")],
    outputs=gr.Image(type="pil", label="SimPool Attention Map", width=width_display, height=height_display),
    examples=example_list,
    title="Explore the Attention Maps of SimPool🔍",
    description="Upload an image or pick one of the examples to explore where a ViT-S model with SimPool, pretrained on ImageNet-1k with DINO, focuses its attention.",
)

demo = gr.TabbedInterface(
    [attention_interface],
    ["Visualize Attention Maps"],
    title="SimPool Attention Map Visualizer 🌌",
)

if __name__ == "__main__":
    demo.launch()
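# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the app): `get_attention_map` can
# also be called directly for a quick smoke test without launching Gradio,
# e.g. on the first bundled example image:
#
#   img = Image.open(example_list[0][0]).convert("RGB")
#   attn_img = get_attention_map(img)
#   attn_img.save("attention_map.png")
# ---------------------------------------------------------------------------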