File size: 4,166 Bytes
5345282
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import pickle

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from einops import reduce
from matplotlib.figure import Figure
from skops import hub_utils
from transformers import ViTFeatureExtractor, ViTModel

# Display names for the ten Imagenette classes. Position i is used both as the
# i-th predict_proba column and as the i-th row of coef_ in
# classify_and_heatmap.
# NOTE(review): assumes this order matches logistic_regression.classes_ — confirm.
labels = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]

# load DINO: the `facebook/dino-vitb16` checkpoint, used in eval mode only to
# produce patch embeddings
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)

# load logistic regression
# `!mkdir` is IPython shell syntax and a SyntaxError in a plain .py script, so
# the (empty) destination directory is created from Python instead. The
# download is skipped when the pickle is already on disk, so re-running the
# script does not try to download into an already-populated directory.
os.makedirs('emb-gam-dino', exist_ok=True)
if not os.path.isfile('emb-gam-dino/model.pkl'):
    hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')

# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load pickles from a Hub repository you trust.
with open('emb-gam-dino/model.pkl', 'rb') as file:
    logistic_regression = pickle.load(file)

def classify_and_heatmap(input_img):
    """Classify an image and plot each patch's contribution to every label.

    Follows the Emb-GAM recipe: the DINO patch embeddings are summed and the
    sum is fed to the logistic regression to get class probabilities. The same
    linear weights (``coef_``) applied to each patch embedding on its own give
    that patch's additive contribution to each class.

    Args:
        input_img: The image supplied by the Gradio ``Image`` input
            (configured with ``type='pil'``).

    Returns:
        ``(scores, fig)`` where ``scores`` maps each name in ``labels`` to its
        predicted probability (as a Python float), and ``fig`` is a matplotlib
        ``Figure`` with the preprocessed input image on the left and one
        patch-contribution heatmap per label.
    """
    # get patch embeddings; index 0 of the sequence is skipped (the [CLS]
    # position), keeping only the patch tokens
    inputs = {
        k: v.to(device)
        for k, v in feature_extractor(input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

    # get scores: Emb-GAM classifies the sum of the patch embeddings
    summed_embedding = reduce(patch_embeddings, 'p d -> () d', 'sum')
    probabilities = logistic_regression.predict_proba(summed_embedding)[0]
    # NOTE(review): assumes the classifier's class order matches `labels`.
    scores = {label: float(p) for label, p in zip(labels, probabilities)}

    # one row per class, one column per patch; the intercept is not included
    patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
    vmin = patch_contributions.min()
    vmax = patch_contributions.max()

    num_patches_side = model.config.image_size // model.config.patch_size

    # set up figure. A standalone Figure (not pyplot) is used so each request
    # gets its own figure that is garbage-collected when no longer referenced;
    # pyplot-created figures stay registered until explicitly closed, which
    # leaks one figure per call and shares global state between requests.
    fig = Figure(figsize=(12, 5))
    axs = fig.subplots(2, 6)
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])

    # plot original image: undo the mean/std normalization. The pixels are
    # moved to CPU first because the mean/std tensors live on CPU (a CUDA
    # tensor times a CPU (3, 1, 1) tensor raises a device-mismatch error).
    pixel_values = inputs['pixel_values'].squeeze(0).cpu()
    image_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
    image_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
    # clamp guards against float error pushing values just outside [0, 1]
    # before the conversion to 8-bit pixels
    img = feature_extractor.to_pil_image(
        (pixel_values * image_std + image_mean).clamp(0, 1)
    )
    ax_orig_img.imshow(img)
    ax_orig_img.axis('off')

    # plot patch contributions, all heatmaps on one shared colour scale
    for i, (label, ax) in enumerate(zip(labels, axs[:, 1:].flat)):
        sns.heatmap(
            patch_contributions[i].reshape(num_patches_side, num_patches_side),
            ax=ax,
            square=True,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(label)
        # sum of this class's patch contributions (its linear score minus the
        # intercept)
        ax.set_xlabel(f'score={patch_contributions[i].sum():.2f}')
        ax.set_xticks([])
        ax.set_yticks([])

    return scores, fig

# Markdown text shown as the demo's description.
description = '''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It does image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contributions of each image patch to each label.
'''

# Markdown text shown below the demo: method notes plus BibTeX citations.
# Raw string so the BibTeX accent escapes (J\'egou, Herv\'e) keep their
# backslashes; in a normal string literal `\'` collapses to a bare `'`.
article = r'''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).

Citation for stuff involved (not our papers):
```bibtex
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}

@InProceedings{Caron_2021_ICCV,
    author    = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
    title     = {Emerging Properties in Self-Supervised Vision Transformers},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {9650-9660}
}

@misc{imagenette,
  author = {fast.ai},
  title = {Imagenette},
  url = {https://github.com/fastai/imagenette},
}
```
'''

# Gradio UI: one PIL image in (shape=(224, 224)); a class label/confidence
# output and the patch-contribution plot out.
demo = gr.Interface(
    fn=classify_and_heatmap,
    inputs=gr.Image(shape=(224, 224), type='pil', label='Input Image'),
    outputs=[
        gr.Label(label='Class'),
        gr.Plot(label='Patch Contributions'),
    ],
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=['./examples/english_springer.png', './examples/golf_ball.png'],
)

# Launch only when executed as a script: with debug=True, launch() blocks the
# calling thread, so launching at import time would hang any importer.
if __name__ == '__main__':
    demo.launch(debug=True)