File size: 4,182 Bytes
5345282
 
5df45f9
5345282
 
5df45f9
5345282
5df45f9
5345282
5df45f9
5345282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5df45f9
5345282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTModel
from skops import hub_utils
from einops import reduce
import matplotlib.pyplot as plt
import seaborn as sns

import pickle
import os

# Imagenette class names. The order matters: entries are paired positionally
# with the classifier's probability columns and coefficient rows further down.
labels = [
    'tench', 'English springer', 'cassette player', 'chain saw', 'church',
    'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute',
]

# Load the DINO backbone together with its matching preprocessor, and put the
# backbone in eval mode on the GPU when one is available (CPU otherwise).
dino_checkpoint = 'facebook/dino-vitb16'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained(dino_checkpoint)
model = ViTModel.from_pretrained(dino_checkpoint)
model = model.eval().to(device)

# load logistic regression
# Only create the directory and download when the pickled model is not already
# on disk: an unconditional os.mkdir raises FileExistsError whenever the app is
# restarted in an environment where the directory survived the previous run.
# NOTE(review): the directory is still created before hub_utils.download, as in
# the original; presumably download expects an existing, empty destination --
# confirm against the skops documentation.
if not os.path.exists('emb-gam-dino/model.pkl'):
  os.makedirs('emb-gam-dino', exist_ok=True)
  hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# This is only acceptable as long as the Hub repository above is trusted.
with open('emb-gam-dino/model.pkl', 'rb') as file:
  logistic_regression = pickle.load(file)

def classify_and_heatmap(input_img):
  """Classify an image and plot every patch's contribution to each label.

  Args:
    input_img: image to classify, in a form accepted by `feature_extractor`
      (the Gradio interface in this file passes a PIL image).

  Returns:
    A tuple `(scores, fig)`. `scores` maps each entry of `labels` to the
    probability returned by `logistic_regression.predict_proba` for the summed
    patch embeddings. `fig` is a matplotlib figure showing the preprocessed
    input image next to one patch-contribution heatmap per label.
  """
  # get patch embeddings
  inputs = {k: v.to(device) for k, v in feature_extractor(input_img, return_tensors='pt').items()}
  with torch.no_grad():
    # skip token 0 (the [CLS] token for ViT) so that only patch tokens remain
    patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

  # get scores
  # sum over patches, then hand scikit-learn an explicit numpy array
  summed_embedding = reduce(patch_embeddings, 'p d -> () d', 'sum').numpy()
  probabilities = logistic_regression.predict_proba(summed_embedding)[0]
  # plain Python floats keep the label payload independent of numpy scalar types
  scores = {label: float(prob) for label, prob in zip(labels, probabilities)}

  # make plot
  num_patches_side = model.config.image_size // model.config.patch_size

  # set up figure: the first column is merged into one tall axis for the image,
  # the remaining 2 x 5 axes hold one heatmap per label
  fig, axs = plt.subplots(2, 6, figsize=(12, 5))
  gs = axs[0, 0].get_gridspec()
  for ax in axs[:, 0]:
    ax.remove()
  ax_orig_img = fig.add_subplot(gs[:, 0])

  # plot original image
  # Undo the normalization on the CPU: the mean/std tensors are created on the
  # CPU, so combining them with `pixel_values` while it is still on a CUDA
  # device would raise a device-mismatch error.
  pixel_values = inputs['pixel_values'].squeeze(0).cpu()
  image_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
  image_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
  img = feature_extractor.to_pil_image(pixel_values * image_std + image_mean)
  ax_orig_img.imshow(img)
  ax_orig_img.axis('off')

  # plot patch contributions
  # one row per label, one column per patch
  patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
  # shared color scale so heatmaps are comparable across labels
  vmin = patch_contributions.min()
  vmax = patch_contributions.max()

  for label, contribution, ax in zip(labels, patch_contributions, axs[:, 1:].flat):
    sns.heatmap(
        contribution.reshape(num_patches_side, num_patches_side),
        ax=ax,
        square=True,
        vmin=vmin,
        vmax=vmax,
    )
    ax.set_title(label)
    ax.set_xlabel(f'score={contribution.sum():.2f}')
    ax.set_xticks([])
    ax.set_yticks([])

  # Return the figure itself rather than the pyplot module, and unregister it
  # from pyplot so figures do not accumulate in memory across requests. The
  # figure object remains usable for rendering after being closed.
  plt.close(fig)
  return scores, fig

# Markdown blurb shown under the title of the demo page.
description='''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It does image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contributions of each image patch to each label.
'''

# Markdown footer of the demo page: implementation notes plus citations.
# NOTE: the backslashes of the BibTeX accents (\'e) are doubled below so they
# survive Python string escaping; a lone \' would be emitted as a bare quote.
article='''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).

Citation for stuff involved (not our papers):
```bibtex
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}

@InProceedings{Caron_2021_ICCV,
    author    = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\\'egou, Herv\\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
    title     = {Emerging Properties in Self-Supervised Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {9650-9660}
}

@misc{imagenette,
  author = {fast.ai},
  title = {Imagenette},
  url = {https://github.com/fastai/imagenette},
}
```
'''

# Assemble the Gradio UI: a single image input, with the predicted label
# distribution and the patch-contribution figure as outputs.
image_input = gr.Image(shape=(224, 224), type='pil', label='Input Image')
result_outputs = [
    gr.Label(label='Class'),
    gr.Plot(label='Patch Contributions'),
]
example_images = ['./examples/english_springer.png', './examples/golf_ball.png']

demo = gr.Interface(
    fn=classify_and_heatmap,
    inputs=image_input,
    outputs=result_outputs,
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=example_images,
)

# Start the web server; debug mode is kept as in the original.
demo.launch(debug=True)