Spaces:
Runtime error
Runtime error
File size: 4,381 Bytes
5345282 5df45f9 5345282 22d82a8 5df45f9 5345282 5df45f9 5345282 5df45f9 5345282 5df45f9 5345282 22d82a8 5345282 22d82a8 5345282 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTModel
from skops import hub_utils
from einops import reduce
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
# Class names shown in the UI. They are zipped with the classifier's
# predict_proba output and used as heatmap titles by index, so the order here
# presumably has to match the class order of the pickled logistic regression
# (verify against the training setup).
labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute',
]
# load DINO
# Use the GPU when one is available; the model is put in eval mode and moved
# to that device once at start-up.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)

# load logistic regression
# Only fetch the repo when the pickle is not already on disk: an unconditional
# os.mkdir raises FileExistsError whenever the process is restarted with the
# directory still present.
if not os.path.exists('emb-gam-dino/model.pkl'):
    os.makedirs('emb-gam-dino', exist_ok=True)
    hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')

# NOTE(security): pickle.load can execute arbitrary code from the file. This is
# only acceptable because the repo above is trusted; do not point it at
# untrusted sources.
with open('emb-gam-dino/model.pkl', 'rb') as file:
    logistic_regression = pickle.load(file)
def classify_and_heatmap(input_img):
    """Classify an image and plot each patch's contribution to every label.

    Args:
        input_img: Image to classify, in a form accepted by
            ``feature_extractor`` (the Gradio input is configured with
            ``type='pil'``).

    Returns:
        A ``(scores, fig)`` tuple. ``scores`` maps each entry of ``labels`` to
        the probability returned by ``logistic_regression.predict_proba``.
        ``fig`` is a matplotlib figure with the preprocessed image on the left
        and one patch-contribution heatmap per label on the right.
    """
    # get patch embeddings
    inputs = {
        k: v.to(device)
        for k, v in feature_extractor(input_img, return_tensors='pt').items()
    }
    with torch.no_grad():
        # Skip token 0 (the [CLS] token in ViT) so only patch tokens remain.
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

    # get scores: the classifier sees the sum of all patch embeddings
    probabilities = logistic_regression.predict_proba(
        reduce(patch_embeddings, 'p d -> () d', 'sum')
    )[0]
    scores = dict(zip(labels, probabilities))

    # make plot
    num_patches_side = model.config.image_size // model.config.patch_size

    # set up figure: the first column is merged into one tall axis for the
    # image, leaving a 2x5 grid of axes for the per-label heatmaps
    fig, axs = plt.subplots(2, 6, figsize=(12, 5))
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])

    # plot original image: undo the feature extractor's normalisation.
    # The pixel values are moved to the CPU first; the mean/std tensors are
    # created on the CPU, so mixing them with a CUDA tensor would raise.
    pixel_values = inputs['pixel_values'].squeeze(0).cpu()
    image_std = torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
    image_mean = torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
    # Clamp away floating-point overshoot before the conversion to 8-bit.
    img = to_pil_image((pixel_values * image_std + image_mean).clamp(0, 1))
    ax_orig_img.imshow(img)
    ax_orig_img.axis('off')

    # plot patch contributions: per-class linear score of every patch, with
    # the intercept spread evenly over the patches so that each map sums to
    # the class's total linear score
    patch_contributions = (
        logistic_regression.coef_ @ patch_embeddings.T.numpy()
        + logistic_regression.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    ).reshape(-1, num_patches_side, num_patches_side)
    # Shared colour scale so heatmaps are comparable across labels.
    vmin = patch_contributions.min()
    vmax = patch_contributions.max()
    for label, contribution, ax in zip(labels, patch_contributions, axs[:, 1:].flat):
        sns.heatmap(
            contribution,
            ax=ax,
            square=True,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(label)
        ax.set_xlabel(f'score={contribution.sum():.2f}')
        ax.set_xticks([])
        ax.set_yticks([])

    # Unregister the figure from pyplot so figures do not pile up across
    # requests; the Figure object itself stays usable for rendering.
    plt.close(fig)
    # Return the figure itself rather than the global pyplot module, whose
    # "current figure" may belong to another request.
    return scores, fig
# Markdown blurb displayed with the demo (passed to gr.Interface as `description`).
description = '''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It does image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contributions of each image patch to each label.
'''
# Markdown displayed with the demo (passed to gr.Interface as `article`).
# Raw string: in a normal string literal the BibTeX accent macro \' collapses
# to a plain apostrophe, which corrupts the author names in the citation.
article = r'''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the set up of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).
Citation for stuff involved (not our papers):
```bibtex
@article{singh2022emb,
title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
author={Singh, Chandan and Gao, Jianfeng},
journal={arXiv preprint arXiv:2209.11799},
year={2022}
}
@InProceedings{Caron_2021_ICCV,
author = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
title = {Emerging Properties in Self-Supervised Vision Transformers},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2021},
pages = {9650-9660}
}
@misc{imagenette,
author = {fast.ai},
title = {Imagenette},
url = {https://github.com/fastai/imagenette},
}
```
'''
# Build the Gradio UI: one image input, and two outputs matching the
# (scores, figure) pair returned by classify_and_heatmap.
demo = gr.Interface(
    fn=classify_and_heatmap,
    # NOTE(review): the `shape` argument of gr.Image appears to have been
    # removed in Gradio 4.x; confirm the pinned Gradio version supports it.
    inputs=gr.Image(shape=(224, 224), type='pil', label='Input Image'),
    outputs=[
        gr.Label(label='Class'),
        gr.Plot(label='Patch Contributions'),
    ],
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=['./examples/english_springer.png', './examples/golf_ball.png'],
)
demo.launch(debug=True)