import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTModel
from skops import hub_utils
from einops import reduce
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute'
]
# load DINO
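# DINO ViT-B/16 produces one 768-dim embedding per 16x16 image patch, plus a
# CLS token (dropped below, since Emb-GAM only uses the patch embeddings)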
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
# load logistic regression
os.mkdir('emb-gam-dino')
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
with open('emb-gam-dino/model.pkl', 'rb') as file:
    logistic_regression = pickle.load(file)
def classify_and_heatmap(input_img):
    # get patch embeddings
    inputs = {k: v.to(device) for k, v in feature_extractor(input_img, return_tensors='pt').items()}
    with torch.no_grad():
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()
    # get scores
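    # Emb-GAM is linear over the sum of patch embeddings: the patch vectors are
    # summed into a single 768-dim feature and scored by the logistic regression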
    scores = dict(zip(
        labels,
        logistic_regression.predict_proba(reduce(patch_embeddings, 'p d -> () d', 'sum'))[0]
    ))
    # make plot
    num_patches_side = model.config.image_size // model.config.patch_size
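    # 224 // 16 = 14, so each class heatmap is a 14 x 14 grid of patch scores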
    # set up figure: a 2 x 6 grid where the first column is merged into a single
    # axis for the original image and the remaining 2 x 5 axes hold one heatmap per class
    fig, axs = plt.subplots(2, 6, figsize=(12, 5))
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])
    # plot original image
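    # undo the feature extractor's normalization (x * std + mean) so the image
    # displays with its original colors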
    img = to_pil_image(
        inputs['pixel_values'].squeeze(0).cpu() * torch.tensor(feature_extractor.image_std).view(-1, 1, 1) + torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
    )
    ax_orig_img.imshow(img)
    ax_orig_img.axis('off')
    # plot patch contributions
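    # since the classifier is linear, each patch's contribution to a class logit
    # is coef . patch_embedding plus an equal share of the intercept, and the
    # contributions summed over all patches recover the class score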
    patch_contributions = (
        logistic_regression.coef_
        @ patch_embeddings.T.numpy()
        + logistic_regression.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    ).reshape(-1, num_patches_side, num_patches_side)
    vmin = patch_contributions.min()
    vmax = patch_contributions.max()
    for i, ax in enumerate(axs[:, 1:].flat):
        sns.heatmap(
            patch_contributions[i],
            ax=ax,
            square=True,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(labels[i])
        ax.set_xlabel(f'score={patch_contributions[i].sum():.2f}')
        ax.set_xticks([])
        ax.set_yticks([])
    return scores, fig
description='''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It performs image classification on [Imagenette](https://github.com/fastai/imagenette) and visualizes the contribution of each image patch to each label.
'''
article='''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).
Citations for the works involved (not our papers):
```bibtex
@article{singh2022emb,
  title = {Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author = {Singh, Chandan and Gao, Jianfeng},
  journal = {arXiv preprint arXiv:2209.11799},
  year = {2022}
}

@InProceedings{Caron_2021_ICCV,
  author = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and J\'egou, Herv\'e and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  title = {Emerging Properties in Self-Supervised Vision Transformers},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2021},
  pages = {9650-9660}
}

@misc{imagenette,
  author = {fast.ai},
  title = {Imagenette},
  url = {https://github.com/fastai/imagenette},
}
```
'''
demo = gr.Interface(
    fn=classify_and_heatmap,
    inputs=gr.Image(shape=(224, 224), type='pil', label='Input Image'),
    outputs=[
        gr.Label(label='Class'),
        gr.Plot(label='Patch Contributions')
    ],
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=['./examples/english_springer.png', './examples/golf_ball.png']
)
demo.launch(debug=True)