import gradio as gr
import torch
from transformers import ViTFeatureExtractor, ViTModel
from skops import hub_utils
from einops import reduce
from torchvision.transforms.functional import to_pil_image
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os

labels = [
    'tench',
    'English springer',
    'cassette player',
    'chain saw',
    'church',
    'French horn',
    'garbage truck',
    'gas pump',
    'golf ball',
    'parachute'
]

# load DINO
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
model = ViTModel.from_pretrained('facebook/dino-vitb16').eval().to(device)
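# Note: for DINO ViT-B/16 (patch size 16, 224x224 inputs) the encoder yields a
# 14x14 grid of patch tokens (196 patches), each a 768-dimensional embedding.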

# load logistic regression
os.makedirs('emb-gam-dino', exist_ok=True)  # avoid crashing if the folder already exists
hub_utils.download(repo_id='Ramos-Ramos/emb-gam-dino', dst='emb-gam-dino')
with open('emb-gam-dino/model.pkl', 'rb') as file:
    logistic_regression = pickle.load(file)

def classify_and_heatmap(input_img):
    # get patch embeddings (drop the CLS token, keep the 196 patch tokens)
    inputs = {k: v.to(device) for k, v in feature_extractor(input_img, return_tensors='pt').items()}
    with torch.no_grad():
        patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

    # get scores: following Emb-GAM, the classifier sees the sum of the patch embeddings
    scores = dict(zip(
        labels,
        logistic_regression.predict_proba(reduce(patch_embeddings, 'p d -> () d', 'sum'))[0]
    ))
    # make plot
    num_patches_side = model.config.image_size // model.config.patch_size

    # set up figure: first column holds the input image, the remaining 2x5 axes the per-class heatmaps
    fig, axs = plt.subplots(2, 6, figsize=(12, 5))
    gs = axs[0, 0].get_gridspec()
    for ax in axs[:, 0]:
        ax.remove()
    ax_orig_img = fig.add_subplot(gs[:, 0])

    # plot original image (undo the feature extractor's normalization; move to CPU first)
    img = to_pil_image(
        inputs['pixel_values'].squeeze(0).cpu()
        * torch.tensor(feature_extractor.image_std).view(-1, 1, 1)
        + torch.tensor(feature_extractor.image_mean).view(-1, 1, 1)
    )
    ax_orig_img.imshow(img)
    ax_orig_img.axis('off')
    # plot patch contributions: because the classifier is linear in the summed embeddings,
    # each class logit decomposes into a per-patch term plus an evenly spread intercept
    patch_contributions = (
        logistic_regression.coef_ @ patch_embeddings.numpy().T
        + logistic_regression.intercept_.reshape(-1, 1) / (num_patches_side ** 2)
    ).reshape(-1, num_patches_side, num_patches_side)
    vmin = patch_contributions.min()
    vmax = patch_contributions.max()
    for i, ax in enumerate(axs[:, 1:].flat):
        sns.heatmap(
            patch_contributions[i],
            ax=ax,
            square=True,
            vmin=vmin,
            vmax=vmax,
        )
        ax.set_title(labels[i])
        ax.set_xlabel(f'score={patch_contributions[i].sum():.2f}')
        ax.set_xticks([])
        ax.set_yticks([])

    return scores, fig
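
# Quick local sanity check (a sketch, not executed by the Space): feed one of the
# bundled example images through the function and print the top class. Assumes the
# example file below exists, as referenced in `examples` further down.
# from PIL import Image
# probs, contribution_fig = classify_and_heatmap(Image.open('./examples/english_springer.png'))
# print(max(probs, key=probs.get))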

description = '''
This demo is a simple extension of [Emb-GAM (Singh & Gao, 2022)](https://arxiv.org/abs/2209.11799) to images. It classifies images from [Imagenette](https://github.com/fastai/imagenette) and visualizes the contribution of each image patch to each label.
'''

article = '''
Under the hood, we use [DINO](https://arxiv.org/abs/2104.14294) to extract patch embeddings and a logistic regression model on top of them, following the setup of the [official Emb-GAM implementation](https://github.com/csinva/emb-gam).
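
Concretely, each image is represented by the sum of its DINO patch embeddings, and the logistic regression is fit on those summed vectors. A rough sketch of that setup (our reconstruction, not the exact training code; `train_images` and `train_labels` stand in for an Imagenette training split):

```python
import torch
from transformers import ViTFeatureExtractor, ViTModel
from sklearn.linear_model import LogisticRegression

extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vitb16')
backbone = ViTModel.from_pretrained('facebook/dino-vitb16').eval()

def summed_patch_embedding(image):
    # sum the 196 patch embeddings (CLS token excluded) into one 768-d vector
    inputs = extractor(image, return_tensors='pt')
    with torch.no_grad():
        patches = backbone(**inputs).last_hidden_state[0, 1:]
    return patches.sum(dim=0).numpy()

# X = [summed_patch_embedding(img) for img in train_images]
# clf = LogisticRegression(max_iter=1000).fit(X, train_labels)
```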

Citations for the works this demo builds on (not our papers):
```bibtex
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}

@InProceedings{Caron_2021_ICCV,
  author = {Caron, Mathilde and Touvron, Hugo and Misra, Ishan and Jégou, Hervé and Mairal, Julien and Bojanowski, Piotr and Joulin, Armand},
  title = {Emerging Properties in Self-Supervised Vision Transformers},
  booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
  month = {October},
  year = {2021},
  pages = {9650-9660}
}

@misc{imagenette,
  author = {fast.ai},
  title = {Imagenette},
  url = {https://github.com/fastai/imagenette}
}
```
'''

demo = gr.Interface(
    fn=classify_and_heatmap,
    inputs=gr.Image(shape=(224, 224), type='pil', label='Input Image'),
    outputs=[
        gr.Label(label='Class'),
        gr.Plot(label='Patch Contributions')
    ],
    title='Emb-GAM DINO',
    description=description,
    article=article,
    examples=['./examples/english_springer.png', './examples/golf_ball.png']
)
demo.launch(debug=True)